Example #1
0
            def forward(self, *args):
                """Run the caller registered for this node and post-process its output.

                Every ``torch.Tensor`` found in ``args`` (descending into nested
                tuples/lists, depth-first, left to right) is passed to
                ``process_inputs_and_params`` together with ``self.node`` and
                ``self.quantizer``.  The caller stored under ``self.node.name``
                in the global NODE_CALLER_MAP is then invoked with the original
                ``args``, and its single output is passed through
                ``post_quant_process``.

                Returns:
                    The post-processed output of the registered caller.
                """

                inputs = []

                def collect_inputs(inputs, value):
                    # Flatten tensors nested in tuples/lists; any other value
                    # type is ignored.
                    if isinstance(value, torch.Tensor):
                        inputs.append(value)
                    elif isinstance(value, (tuple, list)):
                        for i in value:
                            collect_inputs(inputs, i)

                for v in args:
                    collect_inputs(inputs, v)

                # NOTE(review): the processed tensors returned here are NOT
                # forwarded -- the caller below still receives the original
                # ``args``.  This presumably relies on side effects of
                # ``process_inputs_and_params`` (e.g. in-place quantization);
                # confirm before changing.
                _processed_inputs, _ = process_inputs_and_params(
                    self.node, self.quantizer, inputs=inputs)

                caller_map = GLOBAL_MAP.get_ele(NNDCT_KEYS.NODE_CALLER_MAP)
                output = caller_map[self.node.name](*args)
                [output] = post_quant_process(self.node, [output])

                return output
Example #2
0
            def forward(self, *args):
                """Quantize tensors around the caller registered for this node.

                Tensors nested anywhere inside ``args`` (through tuples and
                lists) are gathered depth-first and handed to
                ``quantize_tensors`` with ``tensor_type='input'``.  The caller
                stored under ``self.node.name`` in the global NODE_CALLER_MAP is
                then run on the original ``args``, and its result is passed
                through ``quantize_tensors`` before being returned.
                """

                def iter_tensors(item):
                    # Yield tensors from arbitrarily nested tuples/lists; other
                    # values are skipped.
                    if isinstance(item, torch.Tensor):
                        yield item
                    elif isinstance(item, (tuple, list)):
                        for element in item:
                            yield from iter_tensors(element)

                gathered = [t for arg in args for t in iter_tensors(arg)]

                # NOTE(review): the quantized inputs are not forwarded below;
                # the original ``args`` go to the caller.  Presumably this
                # relies on side effects of ``quantize_tensors`` -- confirm.
                gathered = quantize_tensors(gathered,
                                            self.node,
                                            tensor_type='input')

                callers = GLOBAL_MAP.get_ele(NNDCT_KEYS.NODE_CALLER_MAP)
                result = callers[self.node.name](*args)

                quantized_outputs = quantize_tensors([result], self.node)
                return quantized_outputs[0]
Example #3
0
    def decorate(func):
        """Register ``op_type`` as a custom op and wrap ``func`` as its forward.

        ``op_type``, ``attrs_list`` and ``mapping_to_xir`` come from the
        enclosing scope.  Registration side effects:

        * ``op_type`` is stored in the global CUSTOM_OP_ATTRS_MAP with
          ``attrs_list`` (or ``[]`` when it is ``None``).  A duplicate
          registration only logs an error.
        * When ``mapping_to_xir is True``, ``op_type`` is appended to the global
          CUSTOM_TO_XIR_LIST.

        Args:
            func: Plain function used as the static ``forward`` of a
                ``torch.autograd.Function`` subclass named ``op_type``.

        Returns:
            A wrapper (carrying ``func``'s metadata via ``functools.wraps``)
            that builds that autograd Function and returns ``apply(...)`` of
            its arguments.

        Raises:
            SystemExit: if ``op_type`` clashes with a built-in NNDCT op type
                and ``mapping_to_xir`` is falsy.
            RuntimeError: if ``func`` is not a function, or if ``op_type`` is
                already present in CUSTOM_TO_XIR_LIST.
        """
        if op_type in NNDCT_OP.__dict__.values() and (not mapping_to_xir):
            NndctScreenLogger().error(
                f"'{op_type}' has been defined in pytorch_nndct, please use other type name."
            )
            exit(1)
        if not inspect.isfunction(func):
            # Reject non-functions before any global registration happens.
            raise RuntimeError("This api only decorate a function object")

        custom_op_attr_map = GLOBAL_MAP.get_ele(NNDCT_KEYS.CUSTOM_OP_ATTRS_MAP)
        if custom_op_attr_map is None:
            custom_op_attr_map = {}
            GLOBAL_MAP.set_map(NNDCT_KEYS.CUSTOM_OP_ATTRS_MAP,
                               custom_op_attr_map)
        if op_type in custom_op_attr_map:
            # Duplicate registrations are reported but not fatal here.
            NndctScreenLogger().error(
                f"'{op_type}' can't be registered multiple times.")
        else:
            custom_op_attr_map[
                op_type] = attrs_list if attrs_list is not None else []

        if mapping_to_xir is True:
            NndctScreenLogger().info(f'`{op_type}` has been mapped to xir.')
            custom2xir = GLOBAL_MAP.get_ele(NNDCT_KEYS.CUSTOM_TO_XIR_LIST)
            if custom2xir is None:
                custom2xir = []
                GLOBAL_MAP.set_map(NNDCT_KEYS.CUSTOM_TO_XIR_LIST, custom2xir)
            if op_type in custom2xir:
                raise RuntimeError(
                    f"{op_type} has already been mapped to XIR. Please use this op type instead of custom op."
                )
            custom2xir.append(op_type)

        @functools.wraps(func)
        def inner(*args, **kwargs):
            # A fresh autograd Function subclass named `op_type` is created on
            # every call, with `func` installed as its static forward.
            custom_op = types.new_class(op_type, (torch.autograd.Function, ),
                                        {})
            custom_op.forward = staticmethod(func)
            return custom_op.apply(*args, **kwargs)

        return inner
Example #4
0
 def forward(self, *args):
     """Invoke the caller registered for this node, then post-process its output."""
     callers = GLOBAL_MAP.get_ele(NNDCT_KEYS.NODE_CALLER_MAP)
     raw_output = callers[self.node.name](*args)
     (processed,) = post_quant_process(self.node, [raw_output])
     return processed
Example #5
0
    def do_compile(
            compile_graph: Graph,
            output_file_name=None,
            quant_config_info: Optional[NndctQuantInfo] = None,
            graph_attr_kwargs: Optional[Dict[str, Any]] = None) -> "XGraph":
        """Convert an NNDCT graph into an ``XGraph`` and optionally export it.

        Args:
            compile_graph: NNDCT graph whose parameters and nodes are converted.
            output_file_name: If truthy, the result is exported with
                ``export_to_xmodel`` under this name suffixed with ``'_float'``
                (no quant info) or ``'_int'`` (quant info present).
            quant_config_info: Quantization info forwarded to const/op
                creation.  Forced to ``None`` when
                ``NndctOption.nndct_quant_off`` is set.
            graph_attr_kwargs: Attributes set on the underlying xir graph.

        Returns:
            The populated ``XGraph``.

        Raises:
            AddXopError: if creating any const or op fails; the original
                exception is chained as ``__cause__``.
        """
        if NndctOption.nndct_quant_off.value:
            quant_config_info = None

        xgraph = XGraph(compile_graph.name)

        if graph_attr_kwargs is not None:
            for name, attr in graph_attr_kwargs.items():
                xgraph.graph.set_attr(name, attr)

        # Pass 1: materialize parameter tensors as fixed const ops.
        for node in compile_graph.nodes:
            for param_type, param_tensor in node.op.params.items():
                # For batch norm only gamma and beta are exported as consts.
                if (node.op.type == NNDCT_OP.BATCH_NORM and param_type not in [
                        node.op.ParamName.GAMMA, node.op.ParamName.BETA
                ]):
                    continue
                # Skip params whose const op already exists (e.g. shared ones).
                if xgraph.get_op_by_name(param_tensor.name):
                    continue
                data = np.copy(param_tensor.data)
                if node.op.type in [
                        NNDCT_OP.CONVTRANSPOSE2D,
                        NNDCT_OP.DEPTHWISE_CONVTRANSPOSE2D
                ] and param_type == node.op.ParamName.WEIGHTS:
                    # OHWI -> OH'W'I: reverse the element order along the
                    # h and w axes.
                    data = np.flip(data, (1, 2))
                    data = np.ascontiguousarray(data)
                elif node.op.type in [
                        NNDCT_OP.CONVTRANSPOSE3D,
                        NNDCT_OP.DEPTHWISE_CONVTRANSPOSE3D
                ] and param_type == node.op.ParamName.WEIGHTS:
                    # OHWDI -> OH'W'D'I: reverse the element order along the
                    # h, w and d axes.
                    data = np.flip(data, (1, 2, 3))
                    data = np.ascontiguousarray(data)
                try:
                    xgraph.create_fixed_const_op(name=param_tensor.name,
                                                 data=data,
                                                 quant_info=quant_config_info)
                except Exception as e:
                    raise AddXopError(param_tensor.name, 'const', str(e)) from e

        # Custom op types registered for XIR mapping get a generic converter.
        custom2xir = GLOBAL_MAP.get_ele(NNDCT_KEYS.CUSTOM_TO_XIR_LIST)
        if custom2xir:
            for op_type in custom2xir:
                NNDCTIR2XIR_CONVERTOR[op_type] = to_xir(op_type)

        # Pass 2: convert every node (except RETURN); unknown op types fall
        # back to `custom_xop`.
        for node in compile_graph.nodes:
            if node.op.type == NNDCT_OP.RETURN:
                continue
            try:
                NNDCTIR2XIR_CONVERTOR.get(node.op.type,
                                          custom_xop)(xgraph, node,
                                                      quant_config_info)
            except Exception as e:
                raise AddXopError(node.name, node.op.type, str(e)) from e

        if output_file_name:
            if quant_config_info is None:
                output_file_name += '_float'
            else:
                output_file_name += '_int'

            xgraph.export_to_xmodel(output_file_name)

        return xgraph
Example #6
0
 def forward(self, *args):
     """Run the caller registered for this node and quantize its output."""
     registry = GLOBAL_MAP.get_ele(NNDCT_KEYS.NODE_CALLER_MAP)
     result = registry[self.node.name](*args)
     quantized = quantize_tensors([result], self.node)
     return quantized[0]
Example #7
0
    def layout_tranform(self):
        """Rewrite ``self._dev_graph`` from TORCH (NCHW) to XIR (NHWC) layout.

        The graph is mutated in place in five phases:

        1. *Swim*: from every op listed in ``implicit_ops`` (conv, pool,
           resize, ...), walk towards the inputs.  Layout-transparent ancestors
           (and INPUT/QUANT_STUB/CONST/ZEROS) are tagged with
           ``transpose_out_order``/``transpose_in_order``, and their
           layout-dependent attributes are rewritten via ``special_ops_fn``.
           A ``Permute`` node is inserted wherever propagation has to stop.
        2. *Sink*: from every implicit op and every tagged node, walk towards
           the outputs in the same way, inserting the inverse ``Permute``
           wherever propagation has to stop.
        3. *Neighbor*: for multi-input, non-implicit nodes whose parents are
           only partially tagged, insert a ``Permute`` on each untagged parent
           whose output rank matches the node's.
        4. Fold chains of consecutive ``Permute`` nodes into one.  Drop those
           whose folded order is the identity, and merge sibling permutes
           that share the same order.
        5. Remove the ``Permute`` between a 5-D correlation op and a following
           non-keepdim ``MEAN`` with a 4-D output, and insert a 4-D sink
           ``Permute`` after that ``MEAN`` instead.

        ``Permute`` nodes created by this method are tracked in
        ``nodes_need_to_remove``.  When permutes are collapsed, a pre-existing
        permute is therefore preferred as the survivor.
        """

        custom2xir = GLOBAL_MAP.get_ele(NNDCT_KEYS.CUSTOM_TO_XIR_LIST)
        if custom2xir is None:
            custom2xir = []

        def _find_swim_order(ndim):
            # Channel-first -> channel-last permutation for a rank-`ndim`
            # tensor (rank 2 is the identity).
            return {
                2: [0, 1],
                3: [0, 2, 1],
                4: [0, 2, 3, 1],
                5: [0, 3, 4, 2, 1]
            }[ndim]

        def _find_sink_order(ndim):
            # Inverse permutation of `_find_swim_order(ndim)`.
            return {
                2: [0, 1],
                3: [0, 2, 1],
                4: [0, 3, 1, 2],
                5: [0, 4, 3, 1, 2]
            }[ndim]

        def _is_dim_transparent(node):
            # First input and first output have the same (known) rank.
            return node.in_tensors[0].ndim and node.out_tensors[
                0].ndim and node.in_tensors[0].ndim == node.out_tensors[0].ndim

        def _have_special_layout(node):
            # Rank < 3 outputs need no permutation (rank-2 orders are identity).
            return node.out_tensors[0].ndim and node.out_tensors[0].ndim >= 3

        def _is_custom_op(node):
            # Custom ops that were NOT registered for direct XIR mapping.
            return isinstance(
                node.op, base_op.CustomOp) and node.op.type not in custom2xir

        def _is_permute_op(node):
            return isinstance(node.op, base_op.Permute)

        def _is_terminate_op(node):
            return node.op.type == NNDCT_OP.RETURN

        def _is_identity_order(order):
            # True when the permutation leaves every axis in place.
            return all(index == dim for index, dim in enumerate(order))

        def _make_permute(node_name, dtype, in_quant_part, order):
            # Build a detached Permute node carrying `order` as its ORDER attr.
            op = base_op.Permute(NNDCT_OP.PERMUTE)
            new_node = Node(node_name,
                            op=op,
                            dtype=dtype,
                            in_quant_part=in_quant_part)
            new_node.set_node_attr(new_node.op.AttrName.ORDER, order)
            return new_node

        implicit_ops = [
            NNDCT_OP.CONV2D, NNDCT_OP.DEPTHWISE_CONV2D,
            NNDCT_OP.DEPTHWISE_CONVTRANSPOSE2D, NNDCT_OP.CONVTRANSPOSE2D,
            NNDCT_OP.MAX_POOL, NNDCT_OP.AVG_POOL, NNDCT_OP.ADAPTIVEAVGPOOL2D,
            NNDCT_OP.INTERPOLATE, NNDCT_OP.UP_SAMPLING, NNDCT_OP.RESIZE,
            NNDCT_OP.BATCH_NORM, NNDCT_OP.MAX_POOL1D, NNDCT_OP.CONV1D,
            NNDCT_OP.CONV3D, NNDCT_OP.DEPTHWISE_CONV3D,
            NNDCT_OP.DEPTHWISE_CONVTRANSPOSE3D, NNDCT_OP.CONVTRANSPOSE3D,
            NNDCT_OP.PIXEL_SHUFFLE, NNDCT_OP.PIXEL_UNSHUFFLE,
            NNDCT_OP.RESIZE_3D, NNDCT_OP.RESIZE_NEAREST_3D, NNDCT_OP.REORG,
            NNDCT_OP.CORRELATION1D_ELEMWISE, NNDCT_OP.CORRELATION2D_ELEMWISE,
            NNDCT_OP.COST_VOLUME
        ]

        # Ops whose attributes reference axes/shapes and must be rewritten
        # when their layout changes.
        special_ops_fn = {
            NNDCT_OP.RESHAPE: shape_attr_transform_fn,
            NNDCT_OP.CONCAT: axis_attr_transform_fn,
            NNDCT_OP.STRIDED_SLICE: slice_attr_transform_fn,
            NNDCT_OP.SUM: reduce_op_attr_transform_fn,
            NNDCT_OP.MAX: reduce_op_attr_transform_fn,
            NNDCT_OP.MEAN: reduce_op_attr_transform_fn,
            NNDCT_OP.SHAPE: axis_attr_transform_fn,
            NNDCT_OP.SOFTMAX: axis_attr_transform_fn,
            NNDCT_OP.ZEROS: shape_attr_transform_fn,
        }

        # ---- Phase 1: swim (propagate towards the inputs) -----------------
        insert_pos = [
            node for node in self._dev_graph.nodes
            if node.op.type in implicit_ops
        ]

        swim_transpose = defaultdict(list)
        swim_in_transpose = defaultdict(list)
        sink_transpose = defaultdict(list)

        for node in insert_pos:
            out_ndim = node.out_tensors[0].ndim
            swim_transpose[tuple(_find_swim_order(out_ndim))].append(node)
            swim_in_transpose[node] = tuple(
                _find_swim_order(node.in_tensors[0].ndim))
            sink_transpose[tuple(_find_sink_order(out_ndim))].append(node)

        nodes_need_to_remove = []
        transpose_insert_between_swim = defaultdict(list)
        visited = set()
        for swim_transpose_order, nodes in swim_transpose.items():
            for insert_node in nodes:
                q = deque([insert_node])
                visited.add(insert_node)
                insert_node.transpose_out_order = swim_transpose_order
                insert_node.transpose_in_order = swim_in_transpose[insert_node]
                while q:
                    node = q.popleft()
                    for pn in self._dev_graph.parents(node):
                        if pn in visited:
                            continue
                        if (not _have_special_layout(pn)
                                or pn.op.type in implicit_ops):
                            continue
                        if (pn.op.type in [
                                NNDCT_OP.INPUT, NNDCT_OP.QUANT_STUB,
                                NNDCT_OP.CONST, NNDCT_OP.ZEROS
                        ] or (_is_dim_transparent(pn)
                              and not _is_permute_op(pn)
                              and not _is_custom_op(pn))):
                            # Propagate the child's input layout upwards.
                            pn.transpose_out_order = node.transpose_in_order
                            pn.transpose_in_order = pn.transpose_out_order
                            if pn.op.type in special_ops_fn:
                                special_ops_fn[pn.op.type](
                                    pn, pn.transpose_out_order)
                            q.append(pn)
                            visited.add(pn)
                        else:
                            # Propagation stops: a permute goes between pn
                            # and node.
                            transpose_insert_between_swim[
                                swim_transpose_order].append((pn, node))

        index = 0
        for transpose_order, node_pairs in transpose_insert_between_swim.items():
            for pn, cn in node_pairs:
                new_node = _make_permute(
                    "_".join([pn.name, "swim_transpose", f"{index}"]),
                    pn.dtype, pn.in_quant_part, list(transpose_order))
                self._dev_graph.insert_node_between_nodes(new_node, pn, cn)
                nodes_need_to_remove.append(new_node)
                index += 1

        if transpose_insert_between_swim:
            self._dev_graph.reconnect_nodes()

        # ---- Phase 2: sink (propagate towards the outputs) ----------------
        transpose_insert_between_sink = defaultdict(list)
        visited = set()
        # Every node tagged during the swim also seeds the sink pass.
        for node in self._dev_graph.nodes:
            if node.transpose_out_order:
                nodes = sink_transpose[tuple(
                    _find_sink_order(len(node.transpose_out_order)))]
                if node not in nodes:
                    nodes.append(node)

        for sink_transpose_order, nodes in sink_transpose.items():
            for insert_node in nodes:
                if insert_node in visited:
                    continue
                q = deque([insert_node])
                visited.add(insert_node)
                while q:
                    node = q.popleft()
                    for cn in self._dev_graph.children(node):
                        if cn in visited:
                            continue
                        if cn.op.type in implicit_ops or _is_terminate_op(cn):
                            continue
                        elif cn.op.type == NNDCT_OP.SHAPE:
                            # SHAPE only has its axis attribute rewritten;
                            # propagation does not continue through it.
                            visited.add(cn)
                            if node.transpose_out_order:
                                special_ops_fn[cn.op.type](
                                    cn, node.transpose_out_order)
                        elif cn.transpose_out_order:
                            q.append(cn)
                            visited.add(cn)
                        elif (_is_dim_transparent(cn)
                              and not _is_permute_op(cn)
                              and not _is_custom_op(cn)):
                            cn.transpose_in_order = node.transpose_out_order
                            cn.transpose_out_order = cn.transpose_in_order
                            q.append(cn)
                            visited.add(cn)
                            if cn.op.type in special_ops_fn:
                                special_ops_fn[cn.op.type](
                                    cn, cn.transpose_out_order)
                        else:
                            transpose_insert_between_sink[
                                sink_transpose_order].append((node, cn))

        index = 0
        for transpose_order, node_pairs in transpose_insert_between_sink.items():
            for pn, cn in node_pairs:
                new_node = _make_permute(
                    "_".join([pn.name, "sink_transpose", f"{index}"]),
                    pn.dtype, cn.in_quant_part, list(transpose_order))
                self._dev_graph.insert_node_between_nodes(new_node, pn, cn)
                nodes_need_to_remove.append(new_node)
                index += 1

        if transpose_insert_between_sink:
            self._dev_graph.reconnect_nodes()

        # ---- Phase 3: neighbors (partially tagged multi-input nodes) ------
        neighbor_broadcast = {}
        for node in self._dev_graph.nodes:
            # Single-input nodes have no sibling input to disagree with, and
            # implicit ops already had their inputs handled by the swim pass.
            if len(node.in_nodes) <= 1 or node.op.type in implicit_ops:
                continue
            parent_orders = [
                pn.transpose_out_order for pn in self._dev_graph.parents(node)
            ]
            if (all(order is None for order in parent_orders)
                    or all(order is not None for order in parent_orders)):
                continue
            # Broadcast the first tagged parent's order to the untagged ones.
            neighbor_broadcast[node] = next(
                order for order in parent_orders if order is not None)

        have_neighbors = False
        for node, transpose_order in neighbor_broadcast.items():
            index = 0
            for pn in self._dev_graph.parents(node):
                if (pn.transpose_out_order is None
                        and pn.out_tensors[0].ndim
                        and node.out_tensors[0].ndim
                        and pn.out_tensors[0].ndim == node.out_tensors[0].ndim):
                    new_node = _make_permute(
                        "_".join([node.name, "neighbor_transpose", f"{index}"]),
                        node.dtype, pn.in_quant_part, list(transpose_order))
                    self._dev_graph.insert_node_between_nodes(
                        new_node, pn, node)
                    index += 1
                    nodes_need_to_remove.append(new_node)
                    have_neighbors = True

        if have_neighbors:
            self._dev_graph.reconnect_nodes()

        # ---- Phase 4: fold / drop / merge permutes ------------------------
        def merge_father_and_child(node, visited, transpose_group,
                                   reserverd_nodes):
            # Depth-first walk from `node`.  Permute nodes are accumulated in
            # `transpose_group` while the chain continues.  When a permute has
            # no children, or has a non-permute child, the group is folded
            # into one order written onto a single surviving permute, which
            # is recorded in `reserverd_nodes`.
            visited.add(node)
            if _is_permute_op(node):
                transpose_group.append(node)
                chain_continues = node.out_nodes and all(
                    _is_permute_op(cn)
                    for cn in self._dev_graph.children(node))
                if not chain_continues:
                    order = []
                    reserved_trans = None
                    for trans in transpose_group:
                        # Prefer the last permute that existed before this pass.
                        if trans not in nodes_need_to_remove:
                            reserved_trans = trans
                        trans_order = trans.node_attr(trans.op.AttrName.ORDER)
                        if not order:
                            order = trans_order
                        else:
                            # Compose: apply `order`, then `trans_order`.
                            order = [
                                order[trans_order[i]]
                                for i in range(len(order))
                            ]

                    if reserved_trans is None:
                        reserved_trans = transpose_group[-1]

                    reserved_trans.set_node_attr(
                        reserved_trans.op.AttrName.ORDER, order)
                    reserverd_nodes.append(reserved_trans)
                    transpose_group.clear()

            for cn in self._dev_graph.children(node):
                if cn not in visited:
                    merge_father_and_child(cn, visited, transpose_group,
                                           reserverd_nodes)

        def merge_brothers(reserverd_nodes):
            # When every child (more than one) of a node is a Permute with the
            # same order, redirect users of the redundant ones to a single
            # survivor and destroy the redundant ones that are listed in
            # `reserverd_nodes`.
            remove_nodes = []
            for node in self._dev_graph.nodes:
                if not (len(node.out_nodes) > 1 and all(
                        _is_permute_op(cn)
                        for cn in self._dev_graph.children(node))):
                    continue

                need_merge = True
                order = None
                for trans_node in self._dev_graph.children(node):
                    trans_order = trans_node.node_attr(
                        trans_node.op.AttrName.ORDER)
                    if order is None:
                        order = trans_order
                    elif order != trans_order:
                        need_merge = False
                        break
                if not need_merge:
                    continue

                reserverd_node = None
                for trans_node in self._dev_graph.children(node):
                    if trans_node not in nodes_need_to_remove:
                        reserverd_node = trans_node
                if reserverd_node is None:
                    reserverd_node = self._dev_graph.children(node)[0]

                for trans_node in self._dev_graph.children(node):
                    if (trans_node is not reserverd_node
                            and trans_node in reserverd_nodes):
                        remove_nodes.append(trans_node)
                        trans_node.out_tensors[0].replace_uses_with(
                            reserverd_node.out_tensors[0])

            for node in remove_nodes:
                node.destroy()

            if remove_nodes:
                self._dev_graph.reconnect_nodes()

        source_nodes = [
            node for node in self._dev_graph.nodes if not node.in_tensors
        ]

        transpose_group = []
        reserverd_nodes = []
        visited = set()
        for source in source_nodes:
            merge_father_and_child(source, visited, transpose_group,
                                   reserverd_nodes)

        # Survivors of a fold stay; other permutes created here go, as do
        # survivors whose folded order is the identity.
        nodes_need_to_remove = [
            node for node in nodes_need_to_remove
            if node not in reserverd_nodes
        ]
        nodes_need_to_remove.extend(
            node for node in reserverd_nodes
            if _is_identity_order(node.node_attr(node.op.AttrName.ORDER)))

        for node in nodes_need_to_remove:
            self._dev_graph.remove_node(node)

        merge_brothers(reserverd_nodes)

        # ---- Phase 5: correlation -> MEAN special case --------------------
        def delete_transpose_of_correlation():
            # A non-keepdim MEAN reducing a 5-D correlation output to 4-D is
            # fed through a Permute.  Drop that Permute, rewrite the MEAN's
            # attributes via `special_ops_fn` for the 5-D swim order, and put
            # a 4-D sink Permute after the MEAN instead.
            permutes_to_delete = []
            means_needing_sink = []
            new_sink_permutes = []
            for node in self._dev_graph.nodes:
                if not (node.op.type == NNDCT_OP.MEAN
                        and not node.node_attr(node.op.AttrName.KEEP_DIMS)
                        and self._dev_graph.parents(node)):
                    continue
                pn = self._dev_graph.parents(node)[0]
                if not (pn.in_tensors and _is_permute_op(pn)
                        and self._dev_graph.parents(pn)):
                    continue
                gpn = self._dev_graph.parents(pn)[0]
                if (gpn.op.type in [
                        NNDCT_OP.CORRELATION1D_ELEMWISE,
                        NNDCT_OP.CORRELATION2D_ELEMWISE
                ] and node.out_tensors[0].ndim
                        and gpn.out_tensors[0].ndim == 5
                        and node.out_tensors[0].ndim == 4):
                    permutes_to_delete.append(pn)
                    node.transpose_in_order = tuple(_find_swim_order(5))
                    node.transpose_out_order = tuple(_find_swim_order(4))
                    special_ops_fn[node.op.type](node,
                                                 node.transpose_in_order)
                    means_needing_sink.append(node)

            for index, node in enumerate(means_needing_sink):
                cn = self._dev_graph.children(node)[0]
                # ORDER is stored as a list, like every other permute here, so
                # the order comparison in `merge_brothers` stays meaningful.
                new_node = _make_permute(
                    "_".join([node.name, "sink_transpose", f"{index}"]),
                    node.dtype, node.in_quant_part, list(_find_sink_order(4)))
                self._dev_graph.insert_node_between_nodes(new_node, node, cn)
                new_sink_permutes.append(new_node)

            for node in permutes_to_delete:
                self._dev_graph.remove_node(node)

            transpose_group = []
            reserverd_nodes = []
            visited = set()
            for source in new_sink_permutes:
                merge_father_and_child(source, visited, transpose_group,
                                       reserverd_nodes)

            removable = [
                node for node in new_sink_permutes
                if node not in reserverd_nodes
            ]
            removable.extend(
                node for node in reserverd_nodes
                if _is_identity_order(node.node_attr(node.op.AttrName.ORDER)))

            for node in removable:
                self._dev_graph.remove_node(node)

            merge_brothers(reserverd_nodes)

        delete_transpose_of_correlation()