Example #1
    def __call__(self, graph_name, module, input_args):
        torch_graph_handler = TorchGraphHandler()
        raw_graph = torch_graph_handler.build_torch_graph(
            graph_name, module, input_args)
        self._nndct_graph = Graph(graph_name=raw_graph.name)
        node_convertor = NodeConvertor()
        op_creator = OpCreator()
        for raw_node in raw_graph.nodes:
            nndct_node = node_convertor(self, raw_graph, raw_node)
            if nndct_node:
                self._nndct_graph.add_node(nndct_node)
                nndct_node.op = op_creator(self, raw_graph, raw_node)

        for ret_value_name in raw_graph.ret_values().keys():
            end_tensor = self._nndct_graph.tensor(
                get_full_name(self._nndct_graph.name, ret_value_name))
            self._nndct_graph.add_end_tensor(end_tensor)

        self._convert_blob_tensor_type()
        self._nndct_graph.connect_nodes()
        self._load_data(module)
        if NndctOption.nndct_parse_debug.value >= 2:
            NndctDebugLogger.write(f"nndct raw graph:\n{self._nndct_graph}")
        # print(f"nndct graph:{self._nndct_graph}")
        return self._nndct_graph
Example #2
def merge_multi_subgraphs(graphs: List[Graph],
                          graph_name="Nndctgraph") -> Graph:
    top_graph = Graph(graph_name)
    for graph in graphs:
        for node in graph.nodes:
            top_graph.add_node(node)
    return top_graph
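A minimal usage sketch (g1 and g2 are hypothetical, already-parsed Graph instances):

# Hypothetical graphs g1 and g2 built elsewhere; merge_multi_subgraphs
# copies every node of each input into one new top-level Graph.
top_graph = merge_multi_subgraphs([g1, g2], graph_name="MergedGraph")
print(len(list(top_graph.nodes)))  # total node count after the merge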
Example #3
    def __call__(self, graph: Graph, fuse_conv_bn: bool = True):

        commander = OptimizeCommander(graph=graph)
        commander.DecoupleSharedParamsInConv()
        if fuse_conv_bn:
            #commander.DecoupleSharedParamsInConv()
            commander.FuseBnToConv()
            commander.ConvertBNParams()

            if NndctOption.nndct_equalization.value:
                NndctScreenLogger().info(f"=>Doing weights equalization...")
                #print("before equalization")
                #pre_cov = self._cal_channel_range_coeffience(graph)
                graph = commander.equalize_weights_cross_conv_layers()
                #print("after equalization")
                #self._cal_channel_range_coeffience(graph, pre_cov)

            # if NndctOption.nndct_wes.value:
            #   NndctScreenLogger().info(f"=>Doing weights equalizing shift...")
            #   graph = commander.weights_equalizing_shift()

        if NndctOption.nndct_partition_mode.value > 0:
            self._tag_quant_nodes_v2(graph)
        else:
            self._tag_quant_nodes(graph)
        graph.remove_node_by_types([NNDCT_OP.DROPOUT])
        # print(f"quant_state:")
        # for node in graph.nodes:
        #   print(f"{node.name}:{node.in_quant_part}")
        return graph
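A usage sketch following the call pattern used in Example #6 (graph is an already-parsed Graph):

# BN fusion and optional weight equalization run before quant nodes are
# tagged and DROPOUT nodes are removed from the graph.
optimizer = QuantOptimizer()
graph = optimizer(graph, fuse_conv_bn=True)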
Example #4
 def __init__(self, nndct_graph):
     self._dev_graph = Graph(graph_name=nndct_graph.name)
     self._dev_graph.clone_from(nndct_graph)
     self._evalute_func_map = {
         NNDCT_OP.SHAPE: Evaluator.shape,
         NNDCT_OP.CAST: Evaluator.cast,
         NNDCT_OP.INT: Evaluator.int,
         NNDCT_OP.SCALAR_MUL: Evaluator.mul,
         NNDCT_OP.TENSOR: Evaluator.tensor,
         NNDCT_OP.FLOOR: Evaluator.floor,
         NNDCT_OP.DIV: Evaluator.elemwise_div,
         NNDCT_OP.FLOOR_DIV: Evaluator.floor_div,
         NNDCT_OP.ADD: Evaluator.add,
         NNDCT_OP.SCALAR_ADD: Evaluator.add
     }
Example #5
  def clone_quant_module(cls, quant_module):
    quantizer = GLOBAL_MAP.get_ele(NNDCT_KEYS.QUANTIZER)
  
    if _is_module_hooked(quant_module):
      cls.detach_node_from_module(quant_module)
      cls.hook_module_with_quantizer(quant_module, None)
      new_quant_module = copy.deepcopy(quant_module)
      cls.hook_module_with_node(quant_module, quantizer.graph)
      cls.hook_module_with_quantizer(quant_module, quantizer)
      new_graph = Graph(graph_name=quantizer.graph.name)
      new_graph.clone_from(quantizer.graph)
      cls.hook_module_with_node(new_quant_module, new_graph)
      cls.hook_module_with_quantizer(new_quant_module, quantizer)
    else:
      new_quant_module = copy.deepcopy(quant_module)

    return new_quant_module
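A hedged usage sketch; the owning class is not shown above, so OwnerCls below is a placeholder name, and quant_module is a hooked, quantized torch module prepared elsewhere:

# Detaches graph/quantizer hooks, deep-copies the module, then re-hooks
# both the original module and the fresh copy (with a cloned Graph).
new_quant_module = OwnerCls.clone_quant_module(quant_module)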
Example #6
def quant_optimize(graph: Graph):
  def _execute_optimize(block):
    optimizer = QuantOptimizer()
    graph = optimizer(block)
    return graph
    
  for block in graph.block_subgraphs():
      quant_optimize(block)
  graph = _execute_optimize(graph)
  return graph
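A minimal call sketch, assuming graph is a parsed Graph:

# Optimizes nested block subgraphs recursively, then the top-level graph.
graph = quant_optimize(graph)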
Example #7
def collect_all_blocks(graph: Graph,
                       blocks: Optional[List[Graph]] = None) -> List[Graph]:
    if blocks is None:
        blocks: List[Graph] = []

    for subgraph in graph.block_subgraphs():
        blocks.append(subgraph)
        if list(subgraph.block_subgraphs()):
            collect_all_blocks(subgraph, blocks)

    return blocks
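Usage sketch, assuming graph is a parsed Graph:

# Collects every nested block subgraph, depth-first, into one flat list.
all_blocks = collect_all_blocks(graph)
for block in all_blocks:
    print(block.name)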
Example #8
    def __call__(self, graph: Graph, fuse_conv_bn: bool = True):

        commander = OptimizeCommander(graph=graph)
        if fuse_conv_bn:
            commander.FuseBnToConv()

            if NndctOption.nndct_equalization.value:
                NndctScreenLogger().info(f"=>Doing weights equalization...")
                #print("before equalization")
                #pre_cov = self._cal_channel_range_coeffience(graph)
                commander.equalize_weights_cross_conv_layers()
                #print("after equalization")
                #self._cal_channel_range_coeffience(graph, pre_cov)

        self._tag_quant_nodes(graph)
        graph.remove_node_by_types([NNDCT_OP.DROPOUT, NNDCT_OP.DEQUANT_STUB])
        # print(f"quant_state:")
        # for node in graph.nodes:
        #   print(f"{node.name}:{node.in_quant_part}")
        return graph
Example #9
 def verify_xmodel(compile_graph: Graph, xgraph: XGraph):
     """verify the xmodel by nndct node shape"""
     sorted_nodes = compile_graph.top_sort_nodeset(list(
         compile_graph.nodes))
     for node in sorted_nodes:
         if node.out_tensors[0].ndim and node.out_tensors[0].ndim > 1:
             xop_shape = xgraph.get_op_output_shape(node.name)
             if tuple(xop_shape) != tuple(node.out_tensors[0].shape):
                 NndctScreenLogger().error(
                     f"output shape of {node.name}({node.out_tensors[0].shape}) is different from the output shape of XIR ({xop_shape})"
                 )
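Call sketch, assuming compile_graph and xgraph were produced by the compilation flow:

# Compares each nndct node's output shape with the matching XIR op shape
# and logs an error on any mismatch.
verify_xmodel(compile_graph, xgraph)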
Example #10
    def _convert_graph(self, raw_graph):
        nndct_graph = Graph(graph_name=raw_graph.name)
        self.cur_graph = nndct_graph
        graph_input = self._convert_node(raw_graph.head_node)
        graph_return = self._convert_node(raw_graph.return_node)
        top_block = Block(nndct_graph, None, graph_input, graph_return)
        self.cur_block = top_block
        nndct_graph.set_top_block(top_block)

        self._convert_params(raw_graph)

        pbar = tqdm(list(raw_graph.nodes), bar_format="{bar:50}{r_bar}")
        for raw_node in pbar:
            pbar.set_postfix_str(
                f"OpInfo: name = {raw_node.name}, type = {raw_node.kind}")
            pbar.update()
            nndct_node = self._convert_node(raw_node)
            if not nndct_node.in_node_list():
                nndct_graph.append_node(nndct_node)

            for sub_block in raw_node.blocks:
                cur_block = self.cur_block
                self.cur_block = None
                node_block = self._convert_block(nndct_graph, nndct_node,
                                                 sub_block)
                self.cur_block = cur_block
                nndct_node.add_block(node_block)
            self._bind_free_params(nndct_node)

        # for node in nndct_graph.nodes:
        #   print(node.name, node.in_nodes, node.out_nodes, node.topo_position)
        return nndct_graph
Example #11
def update_nndct_blob_data(module: torch.nn.Module,
                           graph: Graph,
                           time_step: Optional[int] = None) -> NoReturn:
    ModuleHooker.update_blobs_once(module, graph, time_step)
    permute_nodes = graph.find_nodes_by_types([NNDCT_OP.PERMUTE])
    for node in permute_nodes:
        in_data = node.in_tensors[0].data
        if in_data is not None:
            data = permute_data(in_data,
                                node.node_attr(node.op.AttrName.ORDER))
            node.out_tensors[0].from_ndarray(data)
        else:
            NndctScreenLogger().warning(f"{node.__repr__()} has no data.")
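Usage sketch, assuming module has already been run so the hooked blob data is available and graph is its parsed Graph:

# Refreshes blob tensors from the hooked module, then recomputes the
# outputs of PERMUTE nodes from their (permuted) input data.
update_nndct_blob_data(module, graph)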
Example #12
    def _write_forward(self, f, graph: Graph):
        indent_str = 4 * " "
        f.write('\n' + indent_str + "def forward(self, *args):\n")
        indent_str += indent_str
        for node in graph.nodes:
            forward_str, output_str = self._get_forward_str(node)
            f.write(indent_str + forward_str + '\n')

        return_str = indent_str + 'return '
        for i, end_tensor in enumerate(graph.get_end_tensors()):
            if i > 0:
                return_str = ','.join(
                    [return_str, self.tensor_output_map[end_tensor.name]])
            else:
                return_str += self.tensor_output_map[end_tensor.name]

        f.write(return_str + '\n')
Example #13
def unknown_op_type_check(graph: Graph):
    unkown_ops = set()
    custom_ops = set()
    for node in graph.all_nodes():
        if isinstance(node.op, TorchUnknownOperation):
            unkown_ops.add(node.op.type)
        elif node.has_custom_op():
            custom_ops.add(node.op.type)
    for op in custom_ops:
        NndctScreenLogger().warning(
            f"The quantizer recognize new op `{op}` as a float operator by default."
        )


#   if custom_ops:
#     NndctScreenLogger().info(f"You can make these new ops quantizable by add them to custom_quant_ops, \
# e.g. quantizer= torch_quantizer(..., custom_quant_ops=['{list(custom_ops)[0]}',...])")

    NndctScreenLogger().check(f"Unsupported Ops: {unkown_ops}",
                              len(unkown_ops) == 0)
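Call sketch:

# Collects unknown and custom op types from the graph, warns about custom
# ops handled as float operators, and reports any unsupported ops.
unknown_op_type_check(graph)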
Example #14
    def _convert_graph(self, raw_graph, device_type):
        nndct_graph = Graph(graph_name=raw_graph.name)
        node_convertor = NodeConvertor()
        op_creator = OpCreator(device_type)
        pbar = tqdm(list(raw_graph.nodes), bar_format="{bar:50}{r_bar}")
        #for raw_node in raw_graph.nodes:
        for raw_node in pbar:
            pbar.set_postfix_str(
                f"OpInfo: name = {raw_node.name}, type = {raw_node.kind}")
            pbar.update()
            nndct_node = node_convertor(self,
                                        raw_node,
                                        node_scope=nndct_graph.name)
            if nndct_node:
                nndct_graph.add_node(nndct_node)
                nndct_node.op = op_creator(self, nndct_node)
                for i, block in enumerate(raw_node.blocks):
                    nndct_block = self._convert_graph(block, device_type)
                    nndct_node.add_block(nndct_block)

                self._bind_free_params(nndct_node)

        def _construct_ret_struct(values, return_struct):
            for value in values:
                if isinstance(value, list):
                    inner_list = []
                    _construct_ret_struct(value, inner_list)
                    return_struct.append(inner_list)
                else:
                    end_tensor = self.get_nndct_value(value)
                    assert end_tensor is not None
                    return_struct.append(end_tensor)
                    nndct_graph.add_end_tensor(end_tensor)

        nndct_graph.return_struct = []
        _construct_ret_struct(raw_graph.ret_values().values(),
                              nndct_graph.return_struct)
        nndct_graph.connect_nodes()

        return nndct_graph
Example #15
    def _collect_reuse_output(self, graph: Graph):
        def dfs(node, visited):
            visited.append(node)
            if len(node.out_tensors) == 1 \
            and graph.parents(node) \
            and len(graph.parents(node)[0].out_nodes) == 1 \
            and graph.parents(node)[0].out_tensors[0] not in graph.end_tensors:
                self._reuse_node_output_map[node.name] = graph.parents(
                    node)[0].out_tensors[0].name

            for cn in graph.children(node):
                if cn not in visited:
                    dfs(cn, visited)

        if list(graph.block_subgraphs()):
            return

        visited = []
        input_nodes = [
            node for node in graph.nodes if node.op.type == NNDCT_OP.INPUT
        ]
        for node in input_nodes:
            dfs(node, visited)
Example #16
    def partition_by_quant_part(self) -> List[List[Graph]]:
        if not any([
                node.op.type == NNDCT_OP.QUANT_STUB
                for node in self._dev_graph.nodes
        ]):
            return [[self._dev_graph]]

        id2nodes = defaultdict(set)

        def collect_node_set(node, set_id, visited=None):

            if visited is None:
                visited = []

            if node.op.type == NNDCT_OP.RETURN:
                return

            if not hasattr(node, "set_id"):
                node.set_id = set_id

            id2nodes[set_id].add(node)
            visited.append(node)

            for cn in self._dev_graph.children(node):
                if cn not in visited and cn.in_quant_part:
                    collect_node_set(cn, set_id, visited)

        def get_set_id_from_nodeset(nodeset):
            return min([node.set_id for node in nodeset])

        def partition_check(quant_graphs, node_graph_id):
            for node_name, graph_id in node_graph_id.items():
                if len(graph_id) > 1:
                    NndctScreenLogger().error(
                        f"The subgraph{graph_id} hold {node_name} at the same time."
                    )
            for node in self._dev_graph.nodes:
                if node.op.type == NNDCT_OP.RETURN:
                    continue
                if node.in_quant_part and all(
                    [node not in graph for graph in quant_graphs]):
                    raise RuntimeError(
                        f"Please check graph partition: the quant node '{node.name}' should be in quant graph."
                    )
                elif not node.in_quant_part and any(
                    [node in graph for graph in quant_graphs]):
                    raise RuntimeError(
                        f"Please check graph partition: the non-quant node '{node.name}' included in quant graph."
                    )

        set_id = 0
        for node in self._dev_graph.nodes:
            visited = []
            if node.op.type == NNDCT_OP.QUANT_STUB or (not node.in_nodes
                                                       and node.in_quant_part):
                collect_node_set(node, set_id, visited)
                set_id += 1

        merged_id2nodes = defaultdict(set)
        for _, nodeset in id2nodes.items():
            id = get_set_id_from_nodeset(nodeset)
            merged_id2nodes[id].update(nodeset)

        quant_dev_graph = []
        node_graph_id = defaultdict(list)
        for graph_id, nodes in merged_id2nodes.items():
            for node in nodes:
                node_graph_id[node.name].append(graph_id)
            subgraph = Graph.create_subgraph_from_nodeset(
                self._dev_graph, nodes, f"{self._dev_graph.name}_{graph_id}")
            quant_dev_graph.append(subgraph)

        partition_check(quant_dev_graph, node_graph_id)
        if NndctOption.nndct_dump_no_quant_part.value:
            return [quant_dev_graph, [self._dev_graph]]
        else:
            return [quant_dev_graph]
Example #17
    def partition_by_quant_part(self) -> List[Graph]:
        if not any([
                node.op.type == NNDCT_OP.QUANT_STUB
                for node in self._dev_graph.nodes
        ]):
            return [self._dev_graph]

        id2nodes = defaultdict(set)

        def collect_node_set(node, set_id, visited=None):
            # if not node.in_quant_part:
            #   return
            if not visited:
                visited = []

            if not hasattr(node, "set_id"):
                node.set_id = set_id

            id2nodes[set_id].add(node)
            visited.append(node)

            for cn in self._dev_graph.children(node):
                if cn not in visited and cn.in_quant_part:
                    collect_node_set(cn, set_id, visited)

        def get_set_id_from_nodeset(nodeset):
            return min([node.set_id for node in nodeset])

        def partition_check(quant_graphs):
            for node in self._dev_graph.nodes:
                if node.in_quant_part and all(
                    [node not in graph for graph in quant_graphs]):
                    raise RuntimeError(
                        f"Please check graph partition: the quant node '{node.name}' should be in quant graph."
                    )
                elif not node.in_quant_part and any(
                    [node in graph for graph in quant_graphs]):
                    raise RuntimeError(
                        f"Please check graph partition: the non-quant node '{node.name}' included in quant graph."
                    )

        set_id = 0
        for node in self._dev_graph.nodes:
            if node.op.type == NNDCT_OP.QUANT_STUB or (not node.in_nodes
                                                       and node.in_quant_part):
                collect_node_set(node, set_id)
                set_id += 1

        merged_id2nodes = defaultdict(set)
        for nodeset in id2nodes.values():
            id = get_set_id_from_nodeset(nodeset)
            merged_id2nodes[id].update(nodeset)

        quant_dev_graph = []
        for graph_id, nodes in merged_id2nodes.items():
            subgraph = Graph.create_subgraph_from_nodeset(
                self._dev_graph, nodes, f"{self._dev_graph.name}_{graph_id}")
            quant_dev_graph.append(subgraph)

        partition_check(quant_dev_graph)
        return quant_dev_graph
Example #18
class TorchParser(object):
    def __call__(self, graph_name, module, input_args):
        torch_graph_handler = TorchGraphHandler()
        raw_graph = torch_graph_handler.build_torch_graph(
            graph_name, module, input_args)
        self._nndct_graph = Graph(graph_name=raw_graph.name)
        node_convertor = NodeConvertor()
        op_creator = OpCreator()
        for raw_node in raw_graph.nodes:
            nndct_node = node_convertor(self, raw_graph, raw_node)
            if nndct_node:
                self._nndct_graph.add_node(nndct_node)
                nndct_node.op = op_creator(self, raw_graph, raw_node)

        for ret_value_name in raw_graph.ret_values().keys():
            end_tensor = self._nndct_graph.tensor(
                get_full_name(self._nndct_graph.name, ret_value_name))
            self._nndct_graph.add_end_tensor(end_tensor)

        self._convert_blob_tensor_type()
        self._nndct_graph.connect_nodes()
        self._load_data(module)
        if NndctOption.nndct_parse_debug.value >= 2:
            NndctDebugLogger.write(f"nndct raw graph:\n{self._nndct_graph}")
        # print(f"nndct graph:{self._nndct_graph}")
        return self._nndct_graph

    def _convert_blob_tensor_type(self):
        r"""convert torch tensor info to nndct tensor info"""
        for blob_tensor in self._nndct_graph.tensors.values():
            tensor_util.convert_blob_tensor_format(
                blob_tensor, tensor_util.FrameworkType.TORCH,
                tensor_util.FrameworkType.NNDCT)
            blob_tensor.dtype = convert_dtype(blob_tensor.dtype)

    def _load_data(self, module):
        for node in self._nndct_graph.nodes:
            if node.op.type in [NNDCT_OP.BASIC_LSTM, NNDCT_OP.BASIC_GRU]:
                for nndct_param, param_tensors in node.op.params.items():
                    for tensor in param_tensors:
                        data = module.state_dict()[get_short_name(
                            tensor.name)].cpu().numpy()
                        tensor.from_ndarray(data)
                        tensor = tensor_util.convert_parameter_tensor_format(
                            tensor, FrameworkType.TORCH, FrameworkType.NNDCT)
                #combine bias_ih and bias_hh item

                if node.op.type == NNDCT_OP.BASIC_LSTM:
                    for bias_term in [
                            node.op.ParamName.BIAS,
                            node.op.ParamName.BIAS_REVERSE
                    ]:
                        if bias_term in node.op.params and len(
                                node.op.params[bias_term]) > 0:
                            if len(node.op.params[bias_term]) % 2 != 0:
                                raise RuntimeError(
                                    "The num of bias should be even")
                            i = 0
                            bias_list = []
                            while i != len(node.op.params[bias_term]):
                                bias_ih = node.op.params[bias_term][i]
                                bias_hh = node.op.params[bias_term][i + 1]
                                tensor_name = f"bias_{i//2}" if bias_term == node.op.ParamName.BIAS else f"bias_{i//2}_reverse"
                                bias = Tensor(name=get_full_name(
                                    self._nndct_graph.name, tensor_name),
                                              data=bias_ih.data + bias_hh.data)
                                bias_list.append(bias)
                                i = i + 2
                            node.op.set_param(bias_term, bias_list)
            elif node.op.type == NNDCT_OP.CONVTRANSPOSE2D:
                for param_name, tensor in node.op.params.items():
                    data = module.state_dict()[get_short_name(
                        tensor.name)].cpu().numpy()
                    if param_name == node.op.ParamName.WEIGHTS:
                        data = np.copy(data).transpose(1, 0, 2, 3)
                        data = np.ascontiguousarray(data)

                    tensor.from_ndarray(data)
                    tensor = tensor_util.convert_parameter_tensor_format(
                        tensor, FrameworkType.TORCH, FrameworkType.NNDCT)

            elif node.op.type == NNDCT_OP.DEPTHWISE_CONV2D:
                for param_name, tensor in node.op.params.items():
                    data = module.state_dict()[get_short_name(
                        tensor.name)].cpu().numpy()
                    if param_name == node.op.ParamName.WEIGHTS:
                        in_channels = node.node_config("in_channels")
                        out_channels = node.node_config("out_channels")
                        kernel_size = node.node_config("kernel_size")
                        channel_mutiplier = int(out_channels / in_channels)
                        data = np.copy(data).reshape(
                            (channel_mutiplier, in_channels, *kernel_size))

                    tensor.from_ndarray(data)
                    tensor = tensor_util.convert_parameter_tensor_format(
                        tensor, FrameworkType.TORCH, FrameworkType.NNDCT)
            else:
                for param_name, tensor in node.op.params.items():
                    data = module.state_dict()[get_short_name(
                        tensor.name)].cpu().numpy()
                    tensor.from_ndarray(data)
                    tensor = tensor_util.convert_parameter_tensor_format(
                        tensor, FrameworkType.TORCH, FrameworkType.NNDCT)

    def get_blob_tensor_by_name(self, name):
        name = get_full_name(self._nndct_graph.name, name)
        return self._nndct_graph.tensor(name)

    def get_nndct_value(self, torch_value):
        r"""
    three simple types of value : nndct tensor/plain value/None
    """
        def _get_simple_value(value):
            if value.is_none():
                return None
            elif value.is_plain_value():
                return value.data
            else:
                return self.get_blob_tensor_by_name(value.name)

        if isinstance(torch_value, list):
            return [_get_simple_value(value) for value in torch_value]
        else:
            return _get_simple_value(torch_value)
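A hedged end-to-end sketch (the model and example input below are placeholders; exact tracing requirements such as eval mode depend on the library):

import torch

# Placeholder model and dummy input; the parser traces the module and
# returns an nndct Graph with parameters loaded from module.state_dict().
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
dummy_input = torch.randn(1, 3, 32, 32)
parser = TorchParser()
nndct_graph = parser("MyModel", model, (dummy_input,))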
Example #19
  def basic_lstm(self, node):
    graph_scope_name = node.name.split(_GRAPH_SCOPE_SYM)[0]
    node_creator = _NodeCreator()
    graphs = []
    bidirectional = node.node_attr(node.op.AttrName.BIDIRECTIONAL)
    lstm_direction = ["forward"]
    if bidirectional:
      lstm_direction = ["forward", "backward"]
      
    for i in range(node.node_attr(node.op.AttrName.NUM_LAYERS)):
      lstm_cell_pair = {}
      if i == 0:
        input_size = node.node_attr(node.op.AttrName.INPUT_SIZE)
      else:
        input_size = len(lstm_direction) * node.node_attr(
            node.op.AttrName.HIDDEN_SIZE)

      hidden_size = node.node_attr(node.op.AttrName.HIDDEN_SIZE)
      bias=True
      for direction in lstm_direction:
        if direction == "forward":
          w_ih = node.op.params[node.op.ParamName.WEIGHT_IH][i]
          w_hh = node.op.params[node.op.ParamName.WEIGHT_HH][i]
          if node.op.ParamName.BIAS in node.op.params:
            bias_hi = node.op.params[node.op.ParamName.BIAS][i]
          else:
            bias=False
        else:
          w_ih = node.op.params[node.op.ParamName.WEIGHT_IH_REVERSE][i]
          w_hh = node.op.params[node.op.ParamName.WEIGHT_HH_REVERSE][i]
          if node.op.ParamName.BIAS_REVERSE in node.op.params:
            bias_hi = node.op.params[node.op.ParamName.BIAS_REVERSE][i]
          else:
            bias=False
      
        # lstm_node_name = node.name.replace("/", "_")
       
        graph_name = f"{graph_scope_name}_StandardLstmCell_layer_{i}_{direction}"
        graph = Graph(graph_name=graph_name)
        lstm_cell_pair[direction] = graph

        w_ii = Tensor(get_full_name(graph.name, "weight_ii"))
        w_if = Tensor(get_full_name(graph.name, "weight_if"))
        w_ig = Tensor(get_full_name(graph.name, "weight_ig"))
        w_io = Tensor(get_full_name(graph.name, "weight_io"))
        w_ii.from_ndarray(w_ih.data[:hidden_size])
        w_if.from_ndarray(w_ih.data[hidden_size:2 * hidden_size])
        w_ig.from_ndarray(w_ih.data[2 * hidden_size:3 * hidden_size])
        w_io.from_ndarray(w_ih.data[3 * hidden_size:4 * hidden_size])

        w_hi = Tensor(get_full_name(graph.name, "weight_hi"))
        w_hf = Tensor(get_full_name(graph.name, "weight_hf"))
        w_hg = Tensor(get_full_name(graph.name, "weight_hg"))
        w_ho = Tensor(get_full_name(graph.name, "weight_ho"))
        w_hi.from_ndarray(w_hh.data[:hidden_size])
        w_hf.from_ndarray(w_hh.data[hidden_size:2 * hidden_size])
        w_hg.from_ndarray(w_hh.data[2 * hidden_size:3 * hidden_size])
        w_ho.from_ndarray(w_hh.data[3 * hidden_size:4 * hidden_size])

        bias_i = Tensor(get_full_name(graph.name, "bias_i"))
        bias_f = Tensor(get_full_name(graph.name, "bias_f"))
        bias_g = Tensor(get_full_name(graph.name, "bias_g"))
        bias_o = Tensor(get_full_name(graph.name, "bias_o"))

        if bias is True:
          bias_i.from_ndarray(bias_hi.data[:hidden_size])
          bias_f.from_ndarray(bias_hi.data[hidden_size:2 * hidden_size])
          bias_g.from_ndarray(bias_hi.data[2 * hidden_size:3 * hidden_size])
          bias_o.from_ndarray(bias_hi.data[3 * hidden_size:4 * hidden_size])
       
        op = TorchBaseOperation(NNDCT_OP.INPUT, NNDCT_OP.INPUT)
        op.set_config("input", "args[0]")
        shape = [1, input_size]
        node_creator(
            graph=graph,
            node_name="input_0",
            op=op,
            num_out_tensors=1,
            shape=shape)
        op = TorchBaseOperation(NNDCT_OP.INPUT, NNDCT_OP.INPUT)
        op.set_config("input", "args[1]")
        shape = [1, hidden_size]
        node_creator(
            graph=graph,
            node_name="input_1",
            op=op,
            num_out_tensors=1,
            shape=shape)
        op = TorchBaseOperation(NNDCT_OP.INPUT, NNDCT_OP.INPUT)
        op.set_config("input", "args[2]")
        shape = [1, hidden_size]
        node_creator(
            graph=graph,
            node_name="input_2",
            op=op,
            num_out_tensors=1,
            shape=shape)
        # y_i = w_ii * input_0 + bias_i + w_hi * input_1 
        op = TorchLinear()
        op.set_config("bias", bias)
        op.set_config("out_features", hidden_size)
        op.set_config("in_features", input_size)
        op.set_param(op.ParamName.WEIGHTS, w_ii)
        if bias is True:
          op.set_param(op.ParamName.BIAS, bias_i)
        node_creator(
            graph=graph,
            node_name="w_ii * input_0 + bias_i",
            op=op,
            num_out_tensors=1,
            in_tensors=graph.node(get_full_name(graph.name, "input_0")).out_tensors)

        op = TorchLinear()
        op.set_config("bias", False)
        op.set_config("out_features", hidden_size)
        op.set_config("in_features", hidden_size)
        op.set_param(op.ParamName.WEIGHTS, w_hi)
        # op.set_param(op.ParamName.BIAS, bias_i)
        node_creator(
            graph=graph,
            node_name="w_hi * input_1",
            op=op,
            num_out_tensors=1,
            in_tensors=graph.node(get_full_name(graph.name, "input_1")).out_tensors)

        op = TorchAdd()
        op.set_config("input", graph.node(get_full_name(graph.name, "w_ii * input_0 + bias_i")).out_tensors[0])
        op.set_config("other",
                      graph.node(get_full_name(graph.name, "w_hi * input_1")).out_tensors[0])
        node_creator(
            graph=graph,
            node_name="y_i",
            op=op,
            num_out_tensors=1,
            in_tensors=[
                graph.node(get_full_name(graph.name, "w_ii * input_0 + bias_i")).out_tensors[0],
                graph.node(get_full_name(graph.name, "w_hi * input_1")).out_tensors[0]
            ])
        # y_f = w_if * input_0 + bias_f + w_hf * input_1
        op = TorchLinear()
        op.set_config("bias", bias)
        op.set_config("in_features", input_size)
        op.set_config("out_features", hidden_size)
        op.set_param(op.ParamName.WEIGHTS, w_if)
        if bias is True:
          op.set_param(op.ParamName.BIAS, bias_f)
        node_creator(
            graph=graph,
            node_name="w_if * input_0 + bias_f",
            op=op,
            num_out_tensors=1,
            in_tensors=graph.node(get_full_name(graph.name, "input_0")).out_tensors)

        op = TorchLinear()
        op.set_config("bias", False)
        op.set_config("in_features", hidden_size)
        op.set_config("out_features", hidden_size)
        op.set_param(op.ParamName.WEIGHTS, w_hf)
        # op.set_param(op.ParamName.BIAS, bias_f)
        node_creator(
            graph=graph,
            node_name="w_hf * input_1",
            op=op,
            num_out_tensors=1,
            in_tensors=graph.node(get_full_name(graph.name, "input_1")).out_tensors)

        op = TorchAdd()
        op.set_config("input", graph.node(get_full_name(graph.name, "w_if * input_0 + bias_f")).out_tensors[0])
        op.set_config("other",
                      graph.node(get_full_name(graph.name, "w_hf * input_1")).out_tensors[0])
        node_creator(
            graph=graph,
            node_name="y_f",
            op=op,
            num_out_tensors=1,
            in_tensors=[
                graph.node(get_full_name(graph.name, "w_if * input_0 + bias_f")).out_tensors[0],
                graph.node(get_full_name(graph.name, "w_hf * input_1")).out_tensors[0]
            ])

        # y_g = w_ig * input_0 + bias_g + w_hg * input_1
        op = TorchLinear()
        op.set_config("bias", bias)
        op.set_config("in_features", input_size)
        op.set_config("out_features", hidden_size)
        op.set_param(op.ParamName.WEIGHTS, w_ig)
        if bias is True:
          op.set_param(op.ParamName.BIAS, bias_g)
        node_creator(
            graph=graph,
            node_name="w_ig * input_0 + bias_g",
            op=op,
            num_out_tensors=1,
            in_tensors=graph.node(get_full_name(graph.name, "input_0")).out_tensors)

        op = TorchLinear()
        op.set_config("bias", False)
        op.set_config("in_features", hidden_size)
        op.set_config("out_features", hidden_size)
        op.set_param(op.ParamName.WEIGHTS, w_hg)
        # op.set_param(op.ParamName.BIAS, bias_g)
        node_creator(
            graph=graph,
            node_name="w_hg * input_1",
            op=op,
            num_out_tensors=1,
            in_tensors=graph.node(get_full_name(graph.name, "input_1")).out_tensors)

        op = TorchAdd()
        op.set_config("input", graph.node(get_full_name(graph.name, "w_ig * input_0 + bias_g")).out_tensors[0])
        op.set_config("other",
                      graph.node(get_full_name(graph.name, "w_hg * input_1")).out_tensors[0])
        node_creator(
            graph=graph,
            node_name="y_g",
            op=op,
            num_out_tensors=1,
            in_tensors=[
                graph.node(get_full_name(graph.name, "w_ig * input_0 + bias_g")).out_tensors[0],
                graph.node(get_full_name(graph.name, "w_hg * input_1")).out_tensors[0]
            ])

        # y_o = w_io * input_0 +  bias_o + w_ho * input_1
        op = TorchLinear()
        op.set_config("bias", bias)
        op.set_config("in_features", input_size)
        op.set_config("out_features", hidden_size)
        op.set_param(op.ParamName.WEIGHTS, w_io)
        if bias is True:
          op.set_param(op.ParamName.BIAS, bias_o)
        node_creator(
            graph=graph,
            node_name="w_io * input_0 + bias_o",
            op=op,
            num_out_tensors=1,
            in_tensors=graph.node(get_full_name(graph.name, "input_0")).out_tensors)

        op = TorchLinear()
        op.set_config("bias", False)
        op.set_config("in_features", hidden_size)
        op.set_config("out_features", hidden_size)
        op.set_param(op.ParamName.WEIGHTS, w_ho)
        # op.set_param(op.ParamName.BIAS, bias_o)
        node_creator(
            graph=graph,
            node_name="w_ho * input_1",
            op=op,
            num_out_tensors=1,
            in_tensors=graph.node(get_full_name(graph.name, "input_1")).out_tensors)

        op = TorchAdd()
        op.set_config("input", graph.node(get_full_name(graph.name, "w_io * input_0 + bias_o")).out_tensors[0])
        op.set_config("other",
                      graph.node(get_full_name(graph.name, "w_ho * input_1")).out_tensors[0])

        node_creator(
            graph=graph,
            node_name="y_o",
            op=op,
            num_out_tensors=1,
            in_tensors=[
                graph.node(get_full_name(graph.name, "w_io * input_0 + bias_o")).out_tensors[0],
                graph.node(get_full_name(graph.name, "w_ho * input_1")).out_tensors[0]
            ])

        # op = Split(optype=NNDCT_OP.SPLIT)
        # op.set_attr(op.AttrName.INPUT, graph.node("combine_2_linearity").out_tensors[0])
        # op.set_attr(op.AttrName.SPLIT_SIZE_OR_SECTIONS, hidden_size)
        # op.set_attr(op.AttrName.AXIS, 1)
        # node_creator(graph=graph,
        #               node_name="split_ifgo",
        #               op=op,
        #               num_out_tensors=4,
        #               in_tensors=graph.node("combine_2_linearity").out_tensors)

        op = TorchSigmoid()
        node_creator(
            graph=graph,
            node_name="it",
            op=op,
            num_out_tensors=1,
            in_tensors=[graph.node(get_full_name(graph.name, "y_i")).out_tensors[0]])

        op = TorchSigmoid()
        node_creator(
            graph=graph,
            node_name="ft",
            op=op,
            num_out_tensors=1,
            in_tensors=[graph.node(get_full_name(graph.name, "y_f")).out_tensors[0]])

        op = TorchTanh()
        node_creator(
            graph=graph,
            node_name="cct",
            op=op,
            num_out_tensors=1,
            in_tensors=[graph.node(get_full_name(graph.name, "y_g")).out_tensors[0]])

        op = TorchSigmoid()
        node_creator(
            graph=graph,
            node_name="ot",
            op=op,
            num_out_tensors=1,
            in_tensors=[graph.node(get_full_name(graph.name, "y_o")).out_tensors[0]])

        op = TorchMul()
        op.set_config("input", graph.node(get_full_name(graph.name, "it")).out_tensors[0])
        op.set_config("other", graph.node(get_full_name(graph.name, "cct")).out_tensors[0])

        node_creator(
            graph=graph,
            node_name="it*cct",
            op=op,
            num_out_tensors=1,
            in_tensors=[
                graph.node(get_full_name(graph.name, "it")).out_tensors[0],
                graph.node(get_full_name(graph.name, "cct")).out_tensors[0]
            ])

        op = TorchMul()
        op.set_config("input", graph.node(get_full_name(graph.name, "ft")).out_tensors[0])
        op.set_config("other", graph.node(get_full_name(graph.name, "input_2")).out_tensors[0])
        node_creator(
            graph=graph,
            node_name="ft*input_2",
            op=op,
            num_out_tensors=1,
            in_tensors=[
                graph.node(get_full_name(graph.name, "ft")).out_tensors[0],
                graph.node(get_full_name(graph.name, "input_2")).out_tensors[0]
            ])

        op = TorchAdd()
        op.set_config("input", graph.node(get_full_name(graph.name, "it*cct")).out_tensors[0])
        op.set_config("other", graph.node(get_full_name(graph.name, "ft*input_2")).out_tensors[0])
        node_creator(
            graph=graph,
            node_name="c_next",
            op=op,
            num_out_tensors=1,
            in_tensors=[
                graph.node(get_full_name(graph.name, "it*cct")).out_tensors[0],
                graph.node(get_full_name(graph.name, "ft*input_2")).out_tensors[0]
            ])

        op = TorchTanh()
        node_creator(
            graph=graph,
            node_name="c_temp",
            op=op,
            num_out_tensors=1,
            in_tensors=graph.node(get_full_name(graph.name, "c_next")).out_tensors)

        op = TorchMul()
        op.set_config("input", graph.node(get_full_name(graph.name, "ot")).out_tensors[0])
        op.set_config("other", graph.node(get_full_name(graph.name, "c_temp")).out_tensors[0])
        node_creator(
            graph=graph,
            node_name="h_next",
            op=op,
            num_out_tensors=1,
            in_tensors=[
                graph.node(get_full_name(graph.name, "ot")).out_tensors[0],
                graph.node(get_full_name(graph.name, "c_temp")).out_tensors[0]
            ])
        self._connect_nodes(graph)
        graph.add_end_tensor(graph.node(get_full_name(graph.name, "h_next")).out_tensors[0])
        graph.add_end_tensor(graph.node(get_full_name(graph.name, "c_next")).out_tensors[0])
      graphs.append(lstm_cell_pair)
    return graphs
Example #20
class DevGraphOptimizer(object):
    """Optimze graph for device computation
 
  """
    def __init__(self, nndct_graph):
        self._dev_graph = Graph(graph_name=nndct_graph.name)
        self._dev_graph.clone_from(nndct_graph)
        self._evalute_func_map = {
            NNDCT_OP.SHAPE: Evaluator.shape,
            NNDCT_OP.CAST: Evaluator.cast,
            NNDCT_OP.INT: Evaluator.int,
            NNDCT_OP.SCALAR_MUL: Evaluator.mul,
            NNDCT_OP.TENSOR: Evaluator.tensor,
            NNDCT_OP.FLOOR: Evaluator.floor,
            NNDCT_OP.DIV: Evaluator.elemwise_div,
            NNDCT_OP.FLOOR_DIV: Evaluator.floor_div,
            NNDCT_OP.ADD: Evaluator.add,
            NNDCT_OP.SCALAR_ADD: Evaluator.add
        }

    def strip_redundant_ops(self):
        # remove unsupported op in xmodel
        redundant_op_types = [NNDCT_OP.CONTIGUOUS]
        self._dev_graph.remove_node_by_types(redundant_op_types)

    def update_op_attrs(self):
        for node in self._dev_graph.all_nodes():
            if node.op.type == NNDCT_OP.STRIDED_SLICE:
                input_dims = node.in_tensors[0].ndim
                begin = [0] * input_dims
                last = [NNDCT_CONSTANT.INT_MAX] * input_dims
                strides = [1] * input_dims
                dims = node.node_attr(node.op.AttrName.DIMS)
                start = node.node_attr(node.op.AttrName.BEGIN)
                step = node.node_attr(node.op.AttrName.STRIDES)
                end = node.node_attr(node.op.AttrName.END)
                for i, pos in enumerate(dims):
                    begin[pos] = start[i]
                    if isinstance(end[i], Tensor) or (isinstance(end[i], int)
                                                      and end[i] < last[pos]):
                        last[pos] = end[i]

                    strides[pos] = step[i]

                begin_mask = 0
                for dim, pos in enumerate(begin):
                    if pos == 0:
                        begin_mask |= 1 << dim

                end_mask = 0
                for dim, pos in enumerate(end):
                    if isinstance(pos, int) and pos >= NNDCT_CONSTANT.INT_MAX:
                        end_mask |= 1 << dim

                node.set_node_attr(node.op.AttrName.BEGIN, begin)
                node.set_node_attr(node.op.AttrName.BEGIN_MASK, begin_mask)
                node.set_node_attr(node.op.AttrName.END, last)
                node.set_node_attr(node.op.AttrName.END_MASK, end_mask)
                node.set_node_attr(node.op.AttrName.STRIDES, strides)
            elif node.op.type in [
                    NNDCT_OP.SQUEEZE, NNDCT_OP.SUM, NNDCT_OP.MAX, NNDCT_OP.MEAN
            ]:
                new_dims = []
                input_dims = node.in_tensors[0].ndim
                for dim in node.node_attr(node.op.AttrName.DIMS):
                    if dim < 0:
                        new_dims.append(input_dims + dim)
                    else:
                        new_dims.append(dim)
                node.set_node_attr(node.op.AttrName.DIMS, new_dims)

            elif node.op.type == NNDCT_OP.TRANSPOSE:
                input_dims = node.in_tensors[0].ndim
                new_order = list(range(input_dims))
                transpose_order = node.node_attr(node.op.AttrName.ORDER)
                tmp = new_order[transpose_order[0]]
                new_order[transpose_order[0]] = new_order[transpose_order[1]]
                new_order[transpose_order[1]] = tmp
                node.set_node_attr(node.op.AttrName.ORDER, new_order)
            elif node.op.type in [
                    NNDCT_OP.CONCAT, NNDCT_OP.SHAPE, NNDCT_OP.SOFTMAX
            ]:
                input_dims = node.in_tensors[0].ndim
                dim = node.node_attr(node.op.AttrName.AXIS)
                if dim < 0:
                    dim = input_dims + dim
                    node.set_node_attr(node.op.AttrName.AXIS, dim)

            elif node.op.type == NNDCT_OP.ADAPTIVEAVGPOOL2D:
                input_size = node.in_tensors[0].shape  # NCHW
                kernel = [input_size[3], input_size[2]]
                node.set_node_attr(node.op.AttrName.KERNEL, kernel)
                node.set_node_attr(node.op.AttrName.STRIDE, kernel)

    def constant_folding(self):
        folding_nodes = set()
        for node in self._dev_graph.nodes:
            if node.in_quant_part is False:
                continue
            if hasattr(node.op, "AttrName") and node.op.type not in [
                    NNDCT_OP.ADD, NNDCT_OP.SUB, NNDCT_OP.MULTIPLY, NNDCT_OP.DIV
            ]:
                # TODO: Add condition when node.op.type is NNDCT_OP.DIV
                for attr_name in node.op.attrs.keys():
                    attr_val = node.node_attr(attr_name)
                    if isinstance(attr_val, list):
                        for i, val in enumerate(attr_val):
                            attr_val[i] = self._materialize(
                                node, val, folding_nodes)
                    else:
                        attr_val = self._materialize(node, attr_val,
                                                     folding_nodes)
                    if node.op.attrs[attr_name].type == list:
                        attr_val = [attr_val]
                    node.set_node_attr(attr_name, attr_val)

        if folding_nodes:
            for node_name in folding_nodes:
                node = self._dev_graph.node(node_name)
                for out in node.out_tensors:
                    while out.uses:
                        out.uses[0].user.remove_input(out.uses[0].offset)

                self._dev_graph.node(node_name).destroy()

            self._dev_graph.reconnect_nodes()

    @staticmethod
    def _infer_op_value_immediately(op_type):
        return op_type in [NNDCT_OP.SHAPE, NNDCT_OP.CONST]

    def _eval_node_value(self, node):
        if node.out_tensors[0].data is None:
            self._evalute_func_map[node.op.type](node)

    def _materialize(self, cur_node, value, folding_nodes):
        visited = set()

        def dfs(node):
            visited.add(node.name)
            if self._infer_op_value_immediately(node.op.type):
                folding_nodes.add(node.name)
                self._eval_node_value(node)
                return True
            elif hasattr(node, "const_folding") and node.const_folding is True:
                folding_nodes.add(node.name)
                self._eval_node_value(node)
                return True
            elif node.op.type not in self._evalute_func_map:
                return False

            find_evaluable_op = False
            for tensor in node.in_tensors:
                if tensor.node and tensor.node.name not in visited:  # and tensor.data is None:
                    find_evaluable_op = dfs(tensor.node)
                    if find_evaluable_op is False:
                        break

            if find_evaluable_op:
                folding_nodes.add(node.name)
                self._eval_node_value(node)

            return find_evaluable_op

        if not isinstance(value, Tensor):
            return value
        else:
            # if hasattr(value.node, "const_folding") and value.node.const_folding is True:
            #   folding_nodes.add(value.node.name)
            is_evaluable = dfs(value.node)
            if is_evaluable:
                data = value.node.out_tensors[0].data
                input_idx = cur_node.in_tensors.index(value)
                cur_node.remove_input(input_idx)

                if not cur_node.in_tensors and cur_node.op.type not in [
                        NNDCT_OP.ZEROS, NNDCT_OP.QUANT_STUB
                ]:
                    cur_node.const_folding = True

                return data
            else:
                return value

    def layout_tranform(self):
        """layout_transform TORCH(NCHW) -> XIR(NHWC)"""

        custom2xir = GLOBAL_MAP.get_ele(NNDCT_KEYS.CUSTOM_TO_XIR_LIST)
        if custom2xir is None:
            custom2xir = []

        def _find_swim_order(ndim):
            return {
                2: [0, 1],
                3: [0, 2, 1],
                4: [0, 2, 3, 1],
                5: [0, 3, 4, 2, 1]
            }[ndim]

        def _find_sink_order(ndim):
            return {
                2: [0, 1],
                3: [0, 2, 1],
                4: [0, 3, 1, 2],
                5: [0, 4, 3, 1, 2]
            }[ndim]

        def _is_dim_transparent(node):
            return node.in_tensors[0].ndim and node.out_tensors[
                0].ndim and node.in_tensors[0].ndim == node.out_tensors[0].ndim

        def _is_shape_transparent(node):
            return node.in_tensors[0].shape and node.out_tensors[
                0].shape and node.in_tensors[0].shape == node.out_tensors[
                    0].shape

        def _have_special_layout(node):
            return node.out_tensors[0].ndim and node.out_tensors[0].ndim >= 3

        def _is_custom_op(node):
            return isinstance(
                node.op, base_op.CustomOp) and node.op.type not in custom2xir

        def _is_permute_op(node):
            return isinstance(node.op, base_op.Permute)

        def _is_terminate_op(node):
            return node.op.type == NNDCT_OP.RETURN

        implicit_ops = [
            NNDCT_OP.CONV2D, NNDCT_OP.DEPTHWISE_CONV2D,
            NNDCT_OP.DEPTHWISE_CONVTRANSPOSE2D, NNDCT_OP.CONVTRANSPOSE2D,
            NNDCT_OP.MAX_POOL, NNDCT_OP.AVG_POOL, NNDCT_OP.ADAPTIVEAVGPOOL2D,
            NNDCT_OP.INTERPOLATE, NNDCT_OP.UP_SAMPLING, NNDCT_OP.RESIZE,
            NNDCT_OP.BATCH_NORM, NNDCT_OP.MAX_POOL1D, NNDCT_OP.CONV1D,
            NNDCT_OP.CONV3D, NNDCT_OP.DEPTHWISE_CONV3D,
            NNDCT_OP.DEPTHWISE_CONVTRANSPOSE3D, NNDCT_OP.CONVTRANSPOSE3D,
            NNDCT_OP.PIXEL_SHUFFLE, NNDCT_OP.PIXEL_UNSHUFFLE,
            NNDCT_OP.RESIZE_3D, NNDCT_OP.RESIZE_NEAREST_3D, NNDCT_OP.REORG,
            NNDCT_OP.CORRELATION1D_ELEMWISE, NNDCT_OP.CORRELATION2D_ELEMWISE,
            NNDCT_OP.COST_VOLUME
        ]

        special_ops_fn = {
            NNDCT_OP.RESHAPE: shape_attr_transform_fn,
            NNDCT_OP.CONCAT: axis_attr_transform_fn,
            NNDCT_OP.STRIDED_SLICE: slice_attr_transform_fn,
            NNDCT_OP.SUM: reduce_op_attr_transform_fn,
            NNDCT_OP.MAX: reduce_op_attr_transform_fn,
            NNDCT_OP.MEAN: reduce_op_attr_transform_fn,
            NNDCT_OP.SHAPE: axis_attr_transform_fn,
            NNDCT_OP.SOFTMAX: axis_attr_transform_fn,
            NNDCT_OP.ZEROS: shape_attr_transform_fn,
        }

        # collect insert point for transpose
        insert_pos = []
        for node in self._dev_graph.nodes:
            if node.op.type in implicit_ops:
                insert_pos.append(node)

        swim_transpose = defaultdict(list)
        swim_in_transpose = defaultdict(list)
        sink_transpose = defaultdict(list)

        for node in insert_pos:
            tranpose_out_order = tuple(
                _find_swim_order(node.out_tensors[0].ndim))
            swim_transpose[tranpose_out_order].append(node)
            tranpose_in_order = tuple(_find_swim_order(
                node.in_tensors[0].ndim))
            swim_in_transpose[node] = tranpose_in_order
            tranpose_out_order = tuple(
                _find_sink_order(node.out_tensors[0].ndim))
            sink_transpose[tranpose_out_order].append(node)

        nodes_need_to_remove = []
        transpose_insert_between_swim = defaultdict(list)
        visited = []
        # swim_transpose_order, nodes = next(iter(swim_transpose.items()))
        for swim_transpose_order, nodes in swim_transpose.items():
            for insert_node in nodes:
                q = deque()
                q.append(insert_node)
                visited.append(insert_node)
                insert_node.transpose_out_order = swim_transpose_order
                insert_node.transpose_in_order = swim_in_transpose[insert_node]
                while len(q) > 0:
                    node = q.popleft()
                    for pn in self._dev_graph.parents(node):
                        if pn not in visited:

                            if not _have_special_layout(
                                    pn) or pn.op.type in implicit_ops:
                                continue

                            elif pn.op.type in [
                                    NNDCT_OP.INPUT, NNDCT_OP.QUANT_STUB,
                                    NNDCT_OP.CONST, NNDCT_OP.ZEROS
                            ] or _is_dim_transparent(pn) and (
                                    not _is_permute_op(pn)) and (
                                        not _is_custom_op(pn)):
                                pn.transpose_out_order = node.transpose_in_order
                                pn.transpose_in_order = pn.transpose_out_order
                                if pn.op.type in special_ops_fn:
                                    special_ops_fn[pn.op.type](
                                        pn, pn.transpose_out_order)
                                q.append(pn)
                                visited.append(pn)

                            else:
                                # pn.transpose_out_order = [0, 2, 3, 1]
                                transpose_insert_between_swim[
                                    swim_transpose_order].append((pn, node))

        index = 0
        for transpose_order, node_pairs in transpose_insert_between_swim.items(
        ):
            for pn, cn in node_pairs:
                node_name = "_".join([pn.name, "swim_transpose", f"{index}"])
                op = base_op.Permute(NNDCT_OP.PERMUTE)
                new_node = Node(node_name,
                                op=op,
                                dtype=pn.dtype,
                                in_quant_part=pn.in_quant_part)
                new_node.set_node_attr(new_node.op.AttrName.ORDER,
                                       list(transpose_order))
                self._dev_graph.insert_node_between_nodes(new_node, pn, cn)
                nodes_need_to_remove.append(new_node)
                index += 1

        if transpose_insert_between_swim:
            self._dev_graph.reconnect_nodes()

        # debug
        # print("#####swim######")
        # for node in self._dev_graph.nodes:
        #   print(node.op.type, node.name, node.transpose_out_order)

        transpose_insert_between_sink = defaultdict(list)
        visited = []
        for node in self._dev_graph.nodes:
            if node.transpose_out_order:
                nodes = sink_transpose[tuple(
                    _find_sink_order(len(node.transpose_out_order)))]
                if node not in nodes:
                    nodes.append(node)

        for sink_transpose_order, nodes in sink_transpose.items():
            for insert_node in nodes:
                if insert_node not in visited:
                    q = deque()
                    q.append(insert_node)
                    visited.append(insert_node)
                    while len(q) > 0:
                        node = q.popleft()
                        for cn in self._dev_graph.children(node):
                            if cn not in visited:
                                if cn.op.type in implicit_ops or _is_terminate_op(
                                        cn):
                                    continue
                                elif cn.op.type == NNDCT_OP.SHAPE:
                                    visited.append(cn)
                                    if node.transpose_out_order:
                                        special_ops_fn[cn.op.type](
                                            cn, node.transpose_out_order)
                                        continue
                                elif cn.transpose_out_order:
                                    q.append(cn)
                                    visited.append(cn)
                                elif _is_dim_transparent(cn) and (
                                        not _is_permute_op(cn)) and (
                                            not _is_custom_op(cn)):
                                    cn.transpose_in_order = node.transpose_out_order
                                    cn.transpose_out_order = cn.transpose_in_order
                                    q.append(cn)
                                    visited.append(cn)
                                    if cn.op.type in special_ops_fn:
                                        special_ops_fn[cn.op.type](
                                            cn, cn.transpose_out_order)
                                else:
                                    transpose_insert_between_sink[
                                        sink_transpose_order].append(
                                            (node, cn))

        index = 0
        for transpose_order, node_pairs in transpose_insert_between_sink.items(
        ):
            for pn, cn in node_pairs:

                node_name = "_".join([pn.name, "sink_transpose", f"{index}"])
                op = base_op.Permute(NNDCT_OP.PERMUTE)
                new_node = Node(node_name,
                                op=op,
                                dtype=pn.dtype,
                                in_quant_part=cn.in_quant_part)
                new_node.set_node_attr(new_node.op.AttrName.ORDER,
                                       list(transpose_order))
                self._dev_graph.insert_node_between_nodes(new_node, pn, cn)

                nodes_need_to_remove.append(new_node)
                index += 1

        if transpose_insert_between_sink:
            self._dev_graph.reconnect_nodes()

        # debug
        # print("#####sink######")
        # for node in self._dev_graph.nodes:
        #   print(node.op.type, node.name, node.transpose_out_order)
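
        # Neighbor phase: a node with several inputs where only some parents
        # carry a transpose_out_order takes the first available parent order;
        # parents without an order (and with a matching ndim) get an aligning
        # PERMUTE inserted below.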
        neighbor_broadcast = {}
        for node in self._dev_graph.nodes:
            if len(node.in_nodes) <= 1 or node.op.type in implicit_ops:
                continue
            if all([
                    pn.transpose_out_order is None
                    for pn in self._dev_graph.parents(node)
            ]) or all([
                    pn.transpose_out_order is not None
                    for pn in self._dev_graph.parents(node)
            ]):
                continue
            #if node.out_tensors[0].dtype != "float32":
            #  continue
            transpose_order = None
            for pn in self._dev_graph.parents(node):
                transpose_order = pn.transpose_out_order
                if transpose_order is not None:
                    break

            neighbor_broadcast[node] = transpose_order

        have_neighbors = False
        for node, transpose_order in neighbor_broadcast.items():
            index = 0
            for pn in self._dev_graph.parents(node):
                if (pn.transpose_out_order is None and pn.out_tensors[0].ndim
                        and node.out_tensors[0].ndim
                        and pn.out_tensors[0].ndim == node.out_tensors[0].ndim):
                    # pn.transpose_out_order = node.transpose_out_order
                    node_name = "_".join(
                        [node.name, "neighbor_transpose", f"{index}"])
                    op = base_op.Permute(NNDCT_OP.PERMUTE)
                    new_node = Node(node_name,
                                    op=op,
                                    dtype=node.dtype,
                                    in_quant_part=pn.in_quant_part)
                    new_node.set_node_attr(new_node.op.AttrName.ORDER,
                                           list(transpose_order))
                    self._dev_graph.insert_node_between_nodes(
                        new_node, pn, node)

                    index += 1

                    nodes_need_to_remove.append(new_node)
                    have_neighbors = True

        if have_neighbors:
            self._dev_graph.reconnect_nodes()

        # Debug
        # print("####neightbor######")
        # for node in self._dev_graph.nodes:
        #   print(node.op.type, node.name, node.transpose_out_order)
        # remove consecutive transpose

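        # merge_father_and_child collapses chains of consecutive PERMUTE nodes:
        # while every child of a permute is itself a permute, the node is only
        # accumulated in transpose_group; once the chain ends, the accumulated
        # orders are composed into a single reserved transpose and the group is
        # cleared.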
        def merge_father_and_child(node, visited, transpose_group,
                                   reserverd_nodes):
            visited.append(node)
            if _is_permute_op(node):
                if node.out_nodes and all([
                        _is_permute_op(cn)
                        for cn in self._dev_graph.children(node)
                ]):
                    transpose_group.append(node)
                else:
                    transpose_group.append(node)

                    order = []
                    reserved_trans = None
                    for trans in transpose_group:
                        if trans not in nodes_need_to_remove:
                            reserved_trans = trans

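                        # Fold this transpose into the running order. Applying
                        # `order` first and then `tmp_order` is the single
                        # permutation new_order[i] = order[tmp_order[i]]; e.g.
                        # composing (0, 2, 3, 1) with (0, 3, 1, 2) gives the
                        # identity (0, 1, 2, 3), so the chain cancels out.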
                        if not order:
                            order = trans.node_attr(trans.op.AttrName.ORDER)
                        else:
                            new_order = len(order) * [None]
                            tmp_order = trans.node_attr(
                                trans.op.AttrName.ORDER)
                            for i in range(len(order)):
                                t_i = tmp_order[i]
                                new_order[i] = order[t_i]
                            order = new_order

                    if reserved_trans is None:
                        reserved_trans = transpose_group[-1]

                    reserved_trans.set_node_attr(
                        reserved_trans.op.AttrName.ORDER, order)
                    reserverd_nodes.append(reserved_trans)

                    transpose_group.clear()

            for cn in self._dev_graph.children(node):
                if cn not in visited:
                    merge_father_and_child(cn, visited, transpose_group,
                                           reserverd_nodes)

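        # merge_brothers deduplicates sibling permutes: when all children of a
        # node are PERMUTE ops with identical orders, one of them is kept
        # (preferably one not scheduled for removal) and the other reserved
        # permutes have their uses rerouted to it before being destroyed.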
        def merge_brothers(reserverd_nodes):
            remove_nodes = []
            for node in self._dev_graph.nodes:
                if len(node.out_nodes) > 1 and all([
                        _is_permute_op(cn)
                        for cn in self._dev_graph.children(node)
                ]):
                    need_merge = True
                    order = None
                    for trans_node in self._dev_graph.children(node):
                        if order is not None:
                            if order != trans_node.node_attr(
                                    trans_node.op.AttrName.ORDER):
                                need_merge = False
                                break
                        else:
                            order = trans_node.node_attr(
                                trans_node.op.AttrName.ORDER)

                    if need_merge:
                        reserverd_node = None
                        for trans_node in self._dev_graph.children(node):
                            if trans_node not in nodes_need_to_remove:
                                reserverd_node = trans_node

                        if reserverd_node is None:
                            reserverd_node = self._dev_graph.children(node)[0]

                        for trans_node in self._dev_graph.children(node):
                            if trans_node is not reserverd_node and trans_node in reserverd_nodes:
                                remove_nodes.append(trans_node)

                                out_tensor = trans_node.out_tensors[0]
                                out_tensor.replace_uses_with(
                                    reserverd_node.out_tensors[0])

            for node in remove_nodes:
                node.destroy()

            if remove_nodes:
                self._dev_graph.reconnect_nodes()

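        # Drive the chain merge from every source node (nodes without input
        # tensors). Inserted permutes that did not become the reserved node of a
        # chain, plus any reserved permute whose order is the identity, are
        # removed; merge_brothers then deduplicates the surviving siblings.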
        source_nodes = []
        for node in self._dev_graph.nodes:
            if not node.in_tensors:
                source_nodes.append(node)

        transpose_group = []
        reserverd_nodes = []
        visited = []
        for source in source_nodes:
            merge_father_and_child(source, visited, transpose_group,
                                   reserverd_nodes)

        nodes_need_to_remove = [
            node for node in nodes_need_to_remove
            if node not in reserverd_nodes
        ]

        for node in reserverd_nodes:
            order = node.node_attr(node.op.AttrName.ORDER)
            keep_order = True
            if any([index != dim for index, dim in enumerate(order)]):
                keep_order = False
            if keep_order:
                nodes_need_to_remove.append(node)

        for node in nodes_need_to_remove:
            self._dev_graph.remove_node(node)

        merge_brothers(reserverd_nodes)

        # debug
        # print("#####finalize######")
        # for node in self._dev_graph.nodes:
        #   print(node.op.type, node.name, node.transpose_out_order)

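        # Special case for correlation ops: a MEAN (keep_dims=False) that
        # consumes a PERMUTE fed by a 5-d CORRELATION1D/2D_ELEMWISE output has
        # that permute removed, the mean remapped through special_ops_fn, and a
        # 4-d sink PERMUTE inserted after it; the merge helpers above then clean
        # up any duplicates.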
        def delete_transpose_of_correlation(self):
            nodes_need_to_delete_for_special_ops = []
            nodes_need_to_insert_aster_special_ops = []
            nodes_need_to_merge_for_special_ops = []
            for node in self._dev_graph.nodes:
                if (node.op.type == NNDCT_OP.MEAN
                        and not node.node_attr(node.op.AttrName.KEEP_DIMS)
                        and self._dev_graph.parents(node)):
                    pn = self._dev_graph.parents(node)[0]
                    if (pn.in_tensors and _is_permute_op(pn)
                            and self._dev_graph.parents(pn)):
                        gpn = self._dev_graph.parents(pn)[0]
                        if (gpn.op.type in [NNDCT_OP.CORRELATION1D_ELEMWISE,
                                            NNDCT_OP.CORRELATION2D_ELEMWISE]
                                and node.out_tensors[0].ndim
                                and gpn.out_tensors[0].ndim == 5
                                and node.out_tensors[0].ndim == 4):

                            nodes_need_to_delete_for_special_ops.append(pn)

                            node.transpose_in_order = tuple(
                                _find_swim_order(5))
                            node.transpose_out_order = tuple(
                                _find_swim_order(4))
                            special_ops_fn[node.op.type](
                                node, node.transpose_in_order)

                            nodes_need_to_insert_aster_special_ops.append(node)
            index = 0
            for node in nodes_need_to_insert_aster_special_ops:
                cn = self._dev_graph.children(node)[0]
                node_name = "_".join([node.name, "sink_transpose", f"{index}"])
                op = base_op.Permute(NNDCT_OP.PERMUTE)
                new_node = Node(node_name,
                                op=op,
                                dtype=node.dtype,
                                in_quant_part=node.in_quant_part)
                new_node.set_node_attr(new_node.op.AttrName.ORDER,
                                       tuple(_find_sink_order(4)))
                self._dev_graph.insert_node_between_nodes(new_node, node, cn)
                nodes_need_to_merge_for_special_ops.append(new_node)
                index += 1

            for node in nodes_need_to_delete_for_special_ops:
                self._dev_graph.remove_node(node)

            source_nodes = []
            for node in self._dev_graph.nodes:
                if not node.in_tensors:
                    source_nodes.append(node)

            transpose_group = []
            reserverd_nodes = []
            visited = []
            for source in nodes_need_to_merge_for_special_ops:
                merge_father_and_child(source, visited, transpose_group,
                                       reserverd_nodes)

            nodes_need_to_merge_for_special_ops = [
                node for node in nodes_need_to_merge_for_special_ops
                if node not in reserverd_nodes
            ]

            for node in reserverd_nodes:
                order = node.node_attr(node.op.AttrName.ORDER)
                keep_order = True
                if any([index != dim for index, dim in enumerate(order)]):
                    keep_order = False
                if keep_order:
                    nodes_need_to_merge_for_special_ops.append(node)

            for node in nodes_need_to_merge_for_special_ops:
                self._dev_graph.remove_node(node)

            merge_brothers(reserverd_nodes)

        delete_transpose_of_correlation(self)

    def partition_by_quant_part(self) -> List[List[Graph]]:
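        """Split the device graph into quantizable subgraphs.

        Nodes reachable from each QUANT_STUB (or quantizable source node) form
        one node set; sets reached from several seeds are merged before the
        subgraphs are created. Returns [quant_subgraphs], plus [self._dev_graph]
        as a second group when nndct_dump_no_quant_part is enabled, or
        [[self._dev_graph]] when the graph contains no QUANT_STUB at all.
        """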
        if not any([
                node.op.type == NNDCT_OP.QUANT_STUB
                for node in self._dev_graph.nodes
        ]):
            return [[self._dev_graph]]

        id2nodes = defaultdict(set)

        def collect_node_set(node, set_id, visited=None):

            if visited is None:
                visited = []

            if node.op.type == NNDCT_OP.RETURN:
                return

            if not hasattr(node, "set_id"):
                node.set_id = set_id

            id2nodes[set_id].add(node)
            visited.append(node)

            for cn in self._dev_graph.children(node):
                if cn not in visited and cn.in_quant_part:
                    collect_node_set(cn, set_id, visited)

        def get_set_id_from_nodeset(nodeset):
            return min([node.set_id for node in nodeset])

        def partition_check(quant_graphs, node_graph_id):
            for node_name, graph_id in node_graph_id.items():
                if len(graph_id) > 1:
                    NndctScreenLogger().error(
                        f"The subgraphs {graph_id} hold '{node_name}' at the same time."
                    )
            for node in self._dev_graph.nodes:
                if node.op.type == NNDCT_OP.RETURN:
                    continue
                if node.in_quant_part and all(
                    [node not in graph for graph in quant_graphs]):
                    raise RuntimeError(
                        f"Please check graph partition: the quant node '{node.name}' should be in a quant graph."
                    )
                elif not node.in_quant_part and any(
                    [node in graph for graph in quant_graphs]):
                    raise RuntimeError(
                        f"Please check graph partition: the non-quant node '{node.name}' should not be included in a quant graph."
                    )

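        # Seed a node set from every QUANT_STUB and every quantizable source
        # node. A node keeps the first set_id it was assigned, so overlapping
        # sets are later merged under the smallest set_id they contain.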
        set_id = 0
        for node in self._dev_graph.nodes:
            visited = []
            if node.op.type == NNDCT_OP.QUANT_STUB or (not node.in_nodes
                                                       and node.in_quant_part):
                collect_node_set(node, set_id, visited)
                set_id += 1

        merged_id2nodes = defaultdict(set)
        for _, nodeset in id2nodes.items():
            id = get_set_id_from_nodeset(nodeset)
            merged_id2nodes[id].update(nodeset)

        quant_dev_graph = []
        node_graph_id = defaultdict(list)
        for graph_id, nodes in merged_id2nodes.items():
            for node in nodes:
                node_graph_id[node.name].append(graph_id)
            subgraph = Graph.create_subgraph_from_nodeset(
                self._dev_graph, nodes, f"{self._dev_graph.name}_{graph_id}")
            quant_dev_graph.append(subgraph)

        partition_check(quant_dev_graph, node_graph_id)
        if NndctOption.nndct_dump_no_quant_part.value:
            return [quant_dev_graph, [self._dev_graph]]
        else:
            return [quant_dev_graph]

    @property
    def dev_graph(self):
        return self._dev_graph