Example #1
def squeeze(xgraph: XGraph, node: Node,
            quant_config: NndctQuantInfo) -> NoReturn:
    if node.in_tensors[0].ndim == 4 and len(
            node.node_attr(node.op.AttrName.DIMS)) == 1:
        attrs: Dict[str, Any] = {}
        attrs["order"] = [0, 3, 1, 2]

        # resume dimension to NCHW
        input_ops: Dict[str, List[Op]] = {}
        input_list = []
        for input in node.in_nodes:
            input_op = xgraph.get_op_by_name(input)
            input_list.append(input_op)
        input_ops["input"] = input_list
        xgraph.create_fixed_normal_op(node.name + "_i0",
                                      "transpose",
                                      quant_config,
                                      attrs=attrs,
                                      input_ops=input_ops)

        attrs: Dict[str, Any] = {}
        dim = node.node_attr(node.op.AttrName.DIMS)[0]
        dim = transformed_axis("NHWC", "NCHW", ndim=4, dim=dim)
        attrs["axis"] = [dim]
        input_ops: Dict[str, List[Op]] = {}
        input_ops["input"] = [xgraph.get_op_by_name(node.name + "_i0")]
        xgraph.create_fixed_normal_op(node.name,
                                      "squeeze",
                                      quant_config,
                                      attrs=attrs,
                                      input_ops=input_ops)
    else:
        to_xir("squeeze")(xgraph, node, quant_config)
Example #2
def default_xop(xop_type: str, xgraph: XGraph, node: Node,
                quant_config: NndctQuantInfo) -> NoReturn:

    input_ops: Dict[str, List["xir.Op"]] = {}
    if node.has_bound_params():
        for param_name, param_tensor in node.op.params.items():
            param = xgraph.get_op_by_name(param_tensor.name)
            input_ops[param_name.name.lower()] = [param]

    input_list = []
    for input in node.in_tensors:
        if node.has_bound_params() and input.is_param_tensor():
            continue
        elif input.is_param_tensor():
            input_op = xgraph.get_op_by_name(input.name)
        else:
            input_op = xgraph.get_op_by_name(input.node.name)
        input_list.append(input_op)

    input_ops["input"] = xgraph.create_input_fix_ops(input_list, node.name,
                                                     quant_config)

    attrs = _get_xir_attr_from_node(node)
    xgraph.create_fixed_normal_op(node.name,
                                  xop_type,
                                  quant_config,
                                  attrs=attrs,
                                  input_ops=input_ops)
Example #3
def conv_transpose_3d(xgraph: XGraph, node: Node,
                      quant_config: NndctQuantInfo) -> NoReturn:
    attrs = _get_xir_attr_from_node(node)
    attrs['kernel'] = attrs['kernel'][::-1]
    attrs['stride'] = attrs['stride'][::-1]
    attrs['dilation'] = attrs['dilation'][::-1]
    attrs['pad'] = list(
        itertools.chain.from_iterable([[pad] * 2
                                       for pad in attrs['pad'][::-1]]))
    print(attrs)
    input_ops: Dict[str, List["xir.Op"]] = {}
    if node.has_bound_params():
        for param_name, param_tensor in node.op.params.items():
            param = xgraph.get_op_by_name(param_tensor.name)
            input_ops[param_name.name.lower()] = [param]

    input_list = []
    for input in node.in_tensors:
        if node.has_bound_params() and input.is_param_tensor():
            continue
        elif input.is_param_tensor():
            input_op = xgraph.get_op_by_name(input.name)
        else:
            input_op = xgraph.get_op_by_name(input.node.name)
        input_list.append(input_op)

    input_ops["input"] = xgraph.create_input_fix_ops(input_list, node.name,
                                                     quant_config)

    xgraph.create_fixed_normal_op(node.name,
                                  "transposed_conv3d",
                                  quant_config,
                                  attrs=attrs,
                                  input_ops=input_ops)
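The pad handling above reverses the per-dimension pad list and duplicates each entry, turning a per-axis pad triple into (before, after) pairs in reversed axis order. A small standalone sketch of that expansion, assuming symmetric padding given as [pad_d, pad_h, pad_w]:

import itertools

pad = [1, 2, 3]  # e.g. [pad_d, pad_h, pad_w]
xir_pad = list(
    itertools.chain.from_iterable([[p] * 2 for p in pad[::-1]]))
assert xir_pad == [3, 3, 2, 2, 1, 1]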
Example #4
def permute_invar_op(xop_type, xgraph: XGraph, node: Node,
                     quant_config: NndctQuantInfo) -> NoReturn:
    if not node.node_attr(node.op.AttrName.KEEP_DIMS) \
      and node.in_tensors[0].ndim == 4 \
      and len(node.node_attr(node.op.AttrName.DIMS)) == 1 \
      and node.node_attr(node.op.AttrName.DIMS)[0] != 3:
        layout = ["N", "H", "W", "C"]
        del layout[node.node_attr(node.op.AttrName.DIMS)[0]]
        # create mean which keep_dim is True
        attrs: Dict[str, Any] = {}
        attrs["axis"] = node.node_attr(node.op.AttrName.DIMS)
        attrs["keep_dims"] = True
        input_ops: Dict[str, List[Op]] = {}
        input_list = []
        for input in node.in_nodes:
            input_op = xgraph.get_op_by_name(input)
            input_list.append(input_op)
        input_ops["input"] = xgraph.create_input_fix_ops(
            input_list, node.name, quant_config)
        xgraph.create_fixed_normal_op(node.name + "_i0",
                                      xop_type,
                                      quant_config,
                                      attrs=attrs,
                                      input_ops=input_ops)

        attrs: Dict[str, Any] = {}
        if layout == ["N", "H", "C"]:
            attrs["order"] = [0, 3, 1, 2]
        else:
            attrs["order"] = [0, 3, 2, 1]

        # resume dimension to NCHW
        input_ops: Dict[str, List[Op]] = {}
        input_ops["input"] = [xgraph.get_op_by_name(node.name + "_i0")]
        xgraph.create_fixed_normal_op(node.name + "_i1",
                                      "transpose",
                                      quant_config,
                                      attrs=attrs,
                                      input_ops=input_ops)

        attrs: Dict[str, Any] = {}
        if layout == ["N", "H", "C"]:
            attrs["axis"] = [3]
        else:
            attrs["axis"] = [2]
        input_ops: Dict[str, List[Op]] = {}
        input_ops["input"] = [xgraph.get_op_by_name(node.name + "_i1")]
        xgraph.create_fixed_normal_op(node.name,
                                      "squeeze",
                                      quant_config,
                                      attrs=attrs,
                                      input_ops=input_ops)
    else:
        to_xir(xop_type)(xgraph, node, quant_config)
Example #5
def avgpool(xgraph: XGraph, node: Node,
            quant_config: NndctQuantInfo) -> NoReturn:

    needScale = False
    scale = 1.0
    if node.node_attr(node.op.AttrName.KERNEL) == [3, 3]:
        needScale = True
        scale = 9.0 * 7.0 / 64.0
    elif node.node_attr(node.op.AttrName.KERNEL) == [5, 5]:
        needScale = True
        scale = 25.0 * 10.0 / 256.0
    elif node.node_attr(node.op.AttrName.KERNEL) in [[6, 6], [3, 6], [6, 3]]:
        needScale = True
        scale = 36.0 * 7.0 / 256.0
    elif node.node_attr(node.op.AttrName.KERNEL) == [7, 7]:
        needScale = True
        scale = 49.0 * 21.0 / 1024.0
    elif node.node_attr(node.op.AttrName.KERNEL) == [14, 14]:
        needScale = True
        scale = 196.0 * 21.0 / 4096.0

    if needScale:
        attrs = _get_xir_attr_from_node(node)
        # attrs: Dict[str, Any] = {}
        # for attr_name, attr_value in node.op.attrs.items():
        #   attrs[attr_name.value] = _Converter.to_xir_attr_value(attr_name.value, attr_value.value)

        input_ops: Dict[str, List[Op]] = {}
        input_ops["input"] = [xgraph.get_op_by_name(node.in_nodes[0])]
        input_ops["input"] = xgraph.create_input_fix_ops(
            input_ops["input"], node.name, quant_config)
        xgraph.create_fixed_normal_op(node.name + "_i0",
                                      "avgpool2d",
                                      quant_config,
                                      attrs=attrs,
                                      input_ops=input_ops)

        scale = [scale]
        xgraph.create_fixed_const_op(name=node.name + "_i1",
                                     data=np.array(scale, dtype=np.float32),
                                     quant_info=quant_config)

        input_ops: Dict[str, List[Op]] = {}
        input_ops["input"] = [
            xgraph.get_op_by_name(node.name + "_i0"),
            xgraph.get_op_by_name(node.name + "_i1")
        ]
        xgraph.create_fixed_normal_op(node.name,
                                      "mul",
                                      quant_config,
                                      input_ops=input_ops)
    else:
        to_xir("avgpool2d")(xgraph, node, quant_config)
Example #6
def resize(xgraph: XGraph, node: Node,
           quant_config: NndctQuantInfo) -> NoReturn:
    """
  resize is a macro operator, including concat , resize
  """
    attrs: Dict[str, Any] = {}
    # attrs["scale"] = node.node_attr(node.op.AttrName.SCALE)

    attrs["align_corners"] = node.node_attr(node.op.AttrName.ALIGN_CORNERS)
    attrs["half_pixel_centers"] = node.node_attr(
        node.op.AttrName.HALF_PIXEL_CENTERS)
    attrs["mode"] = node.node_attr(node.op.AttrName.MODE)
    # attrs["mode"] = {0: "NEAREST", 3: "BILINEAR"}.get(attrs["mode"])
    size = node.node_attr(node.op.AttrName.SIZE)
    scale = node.node_attr(node.op.AttrName.SCALE)
    # if size[0] == 0 and size[1] == 0:
    if all([s == 0 for s in size]):
        attrs["scale"] = scale
        input_ops: Dict[str, List["xir.Op"]] = {}
        input_list = []
        for input in node.in_nodes:
            input_op = xgraph.get_op_by_name(input)
            input_list.append(input_op)
        input_ops["input"] = xgraph.create_input_fix_ops(
            input_list, node.name, quant_config)
        xgraph.create_fixed_normal_op(node.name,
                                      "resize",
                                      quant_config,
                                      attrs=attrs,
                                      input_ops=input_ops)
    else:
        sub_pack_op, pack_list = _pack(xgraph, node, "size", size,
                                       quant_config)
        input_ops: Dict[str, List["xir.Op"]] = {}
        input_ops["size"] = [sub_pack_op]
        input_list = []
        for input in node.in_nodes:
            input_op = xgraph.get_op_by_name(input)
            input_list.append(input_op)
        input_ops["input"] = input_list
        input_ops["input"] = [
            op for op in input_ops["input"]
            if op.get_name() not in [i.get_name() for i in pack_list]
        ]
        input_ops["input"] = xgraph.create_input_fix_ops(
            input_ops["input"], node.name, quant_config)
        xgraph.create_fixed_normal_op(node.name,
                                      "resize",
                                      quant_config,
                                      attrs=attrs,
                                      input_ops=input_ops)
Example #7
    def _generate_scale_node(node, scale_param, node_idx):

        node_name = node.name
        tensor_out = node.out_tensors[0]
        out_tensor_shape = tensor_out.shape

        scale_op = TorchBaseOperation(NNDCT_OP.CHANNEL_SCALE,
                                      "Channel_Scale",
                                      force_to_primitive=True)
        scale_op.set_config("channel_scale", scale_param)
        scale_op.set_config("input", tensor_out)

        if node.op.type == NNDCT_OP.CHANNEL_SCALE:
            scale_node_name = node_name[:node_name.rfind(".")] + ".2"
        else:
            scale_node_name = "/".join([node_name, "channel_scale.1"])

        scale_node = Node(scale_node_name,
                          op=scale_op,
                          dtype="float32",
                          idx=node_idx,
                          in_quant_part=False)

        out_tensor = Tensor(name=f"{scale_node_name}.0",
                            node=scale_node,
                            shape=out_tensor_shape,
                            dtype="float32",
                            layout=tensor_out.layout)
        scale_node.out_tensors.append(out_tensor)

        scale_node.in_tensors.append(tensor_out)
        return scale_node
Example #8
def shape(xgraph: XGraph, node: Node,
          quant_config: NndctQuantInfo) -> NoReturn:
    r""" nndct shape is a macro operator, including shape, stridedslice 
      """
    # raise NotImplementedError("shape")
    input_list = []
    shape_input_ops: Dict[str, List[Op]] = {}
    for input in node.in_nodes:
        input_op = xgraph.get_op_by_name(input)
        input_list.append(input_op)
    shape_input_ops["input"] = input_list

    sub_op_shape = xgraph.create_fixed_normal_op(node.name + "_i0",
                                                 "shape",
                                                 quant_config,
                                                 input_ops=shape_input_ops)

    attrs: Dict[str, Any] = {}
    strided_slice_input_ops: Dict[str, List[Op]] = {}
    strided_slice_input_ops["input"] = [sub_op_shape]
    dim = node.node_attr(node.op.AttrName.AXIS)
    attrs["begin"] = [dim]
    attrs["end"] = [dim + 1]
    xgraph.create_fixed_normal_op(node.name,
                                  "strided_slice",
                                  quant_config,
                                  attrs=attrs,
                                  input_ops=strided_slice_input_ops)
Example #9
    def __call__(self,
                 graph,
                 node_name,
                 op,
                 num_out_tensors,
                 shape=None,
                 in_tensors=None,
                 in_quant_part=True):

        node_name = get_full_name(graph.name, node_name)
        node = Node(node_name,
                    op=op,
                    dtype="float32",
                    idx=self._idx,
                    in_quant_part=in_quant_part)
        # print(f"{node.name} quant state: {node.in_quant_part}")
        for i in range(num_out_tensors):
            tensor = Tensor(name=f"{node_name}_{i}", node=node, shape=shape)
            node.out_tensors.append(tensor)

        if in_tensors:
            for tensor in in_tensors:
                node.in_tensors.append(tensor)
        graph.add_node(node)
        self._idx += 1
Example #10
def zeros(xgraph: XGraph, node: Node,
          quant_config: NndctQuantInfo) -> NoReturn:
    shape = node.node_attr(node.op.AttrName.SHAPE)
    data = np.zeros(shape,
                    dtype=_Converter.to_numpy_dtype(node.out_tensors[0].dtype))
    xgraph.create_fixed_const_op(name=node.name,
                                 data=data,
                                 quant_info=quant_config)
Example #11
 def _gen_module_name(self, node: Node):
     module_name = TorchSymbol.MODULE_NAME_SEPERATOR.join(
         [TorchSymbol.MODULE_BASE_SYMBOL,
          str(self._global_idx)])
     node.idx = self._global_idx
     self._global_idx += 1
     self._module_names[node.name] = module_name
     return module_name
Example #12
def avgpool(xgraph: XGraph, node: Node,
            quant_config: NndctQuantInfo) -> NoReturn:

    scale = 1.0
    if node.node_attr(node.op.AttrName.KERNEL) == [3, 3]:
        scale = 9.0 * 7.0 / 64.0
    elif node.node_attr(node.op.AttrName.KERNEL) == [5, 5]:
        scale = 25.0 * 10.0 / 256.0
    elif node.node_attr(node.op.AttrName.KERNEL) in [[6, 6], [3, 6], [6, 3]]:
        scale = 36.0 * 7.0 / 256.0
    elif node.node_attr(node.op.AttrName.KERNEL) == [7, 7]:
        scale = 49.0 * 21.0 / 1024.0
    elif node.node_attr(node.op.AttrName.KERNEL) == [14, 14]:
        scale = 196.0 * 21.0 / 4096.0
    else:
        rec = node.node_attr(node.op.AttrName.KERNEL)[0] * node.node_attr(
            node.op.AttrName.KERNEL)[1]
        max_factor = math.ceil(math.log(rec * 128, 2))
        diff = 1.0
        multi_factor = 0.0
        shift_factor = 0.0
        for shift_factor_ in range(max_factor):
            factor = round((2**shift_factor_) / rec)
            diff_ = abs(factor / (2**shift_factor_) - 1 / rec)
            if diff_ < diff:
                multi_factor = factor
                diff = diff_
                shift_factor = shift_factor_
        scale = rec * multi_factor / (2**shift_factor)

    attrs = _get_xir_attr_from_node(node)
    # attrs: Dict[str, Any] = {}
    # for attr_name, attr_value in node.op.attrs.items():
    #   attrs[attr_name.value] = _Converter.to_xir_attr_value(attr_name.value, attr_value.value)

    input_ops: Dict[str, List["xir.Op"]] = {}
    input_ops["input"] = [xgraph.get_op_by_name(node.in_nodes[0])]
    input_ops["input"] = xgraph.create_input_fix_ops(input_ops["input"],
                                                     node.name, quant_config)
    xgraph.create_fixed_normal_op(node.name + "_i0",
                                  "avgpool2d",
                                  quant_config,
                                  attrs=attrs,
                                  input_ops=input_ops)

    scale = [scale]
    xgraph.create_fixed_const_op(name=node.name + "_i1",
                                 data=np.array(scale, dtype=np.float32),
                                 quant_info=quant_config)

    input_ops: Dict[str, List["xir.Op"]] = {}
    input_ops["input"] = [
        xgraph.get_op_by_name(node.name + "_i0"),
        xgraph.get_op_by_name(node.name + "_i1")
    ]
    xgraph.create_fixed_normal_op(node.name,
                                  "mul",
                                  quant_config,
                                  input_ops=input_ops)
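For kernel sizes outside the hard-coded table, the else branch approximates 1/(kh*kw) by multi_factor / 2**shift_factor, and the residual becomes the constant fed into the trailing "mul". A standalone sketch of that search (the function name is ours, for illustration only):

import math

def approx_avgpool_scale(kh, kw):
    # Mirrors the else branch above: pick multi_factor / 2**shift_factor as the
    # closest approximation of 1 / (kh * kw), then return the residual scale.
    rec = kh * kw
    max_factor = math.ceil(math.log(rec * 128, 2))
    diff, multi_factor, shift_factor = 1.0, 0.0, 0.0
    for shift in range(max_factor):
        factor = round((2 ** shift) / rec)
        diff_ = abs(factor / (2 ** shift) - 1 / rec)
        if diff_ < diff:
            multi_factor, diff, shift_factor = factor, diff_, shift
    return rec * multi_factor / (2 ** shift_factor)

# approx_avgpool_scale(4, 4) == 1.0 because 1/16 is exactly representable;
# non-power-of-two kernels give a residual scale slightly off 1.0.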
Example #13
def const_xop(xgraph: XGraph, node: Node,
              quant_config: NndctQuantInfo) -> NoReturn:
    data = node.node_attr(node.op.AttrName.DATA)
    data_type = np.dtype(node.out_tensors[0].dtype)

    if not isinstance(data, list):
        data = [data]

    xgraph.create_fixed_const_op(name=node.name,
                                 data=np.array(data, dtype=data_type),
                                 quant_info=quant_config)
Example #14
    def _get_module_attrs_map(self, node: Node, torch_op_type: str,
                              torch_op_attrs: Dict[str, Any]):
        ordered_attrs = OrderedDict()
        attrs_template = torch_op_attrs if torch_op_attrs and torch_op_attrs != [
            'args', 'kwargs'
        ] else node.op.configs
        for name in attrs_template:
            if hasattr(node.op, name):
                ordered_attrs[name] = node.node_config(name)

        return ordered_attrs
Example #15
def reshape(xgraph: XGraph, node: Node,
            quant_config: NndctQuantInfo) -> NoReturn:
    r""" nndct reshape is a macro operator, including pack, reshape
      """
    # raise NotImplementedError("reshape")

    if node.in_tensors[0].ndim != 4 or node.in_tensors[
            0].layout == Tensor.Layout.NHWC:
        shape = node.node_attr(node.op.AttrName.SHAPE)
        sub_op_pack, pack_list = _pack(xgraph, node, "shape", shape,
                                       quant_config)
        input_ops: Dict[str, List[Op]] = {}
        input_ops["shape"] = [sub_op_pack]
        input_ops["input"] = [xgraph.get_op_by_name(node.in_nodes[0])]
        xgraph.create_fixed_normal_op(node.name,
                                      "reshape",
                                      quant_config,
                                      input_ops=input_ops)

    else:
        shape = node.node_attr(node.op.AttrName.SHAPE)
        sub_op_pack, pack_list = _pack(xgraph, node, "shape", shape,
                                       quant_config)
        attrs: Dict[str, Any] = {}
        # NHWC -> NCHW
        attrs["order"] = [0, 3, 1, 2]
        input_ops: Dict[str, List[Op]] = {}
        input_ops["input"] = [xgraph.get_op_by_name(node.in_nodes[0])]
        xgraph.create_fixed_normal_op(node.name + "_i0",
                                      "transpose",
                                      quant_config,
                                      attrs=attrs,
                                      input_ops=input_ops)

        input_ops: Dict[str, List[Op]] = {}
        input_ops["shape"] = [sub_op_pack]
        input_ops["input"] = [xgraph.get_op_by_name(node.name + "_i0")]
        xgraph.create_fixed_normal_op(node.name,
                                      "reshape",
                                      quant_config,
                                      input_ops=input_ops)
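The second branch inserts a transpose (the "NHWC -> NCHW" comment) before the reshape, and the order matters: reshaping an NHWC buffer directly with a shape derived from the NCHW view reorders elements differently. A small numpy sketch of that difference (shapes are illustrative only):

import numpy as np

nchw = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)   # Torch-side layout
nhwc = nchw.transpose(0, 2, 3, 1)                      # graph-side layout
target = (2, 60)
back_to_nchw = nhwc.transpose(0, 3, 1, 2)              # order [0, 3, 1, 2]
assert np.array_equal(back_to_nchw.reshape(target), nchw.reshape(target))
assert not np.array_equal(nhwc.reshape(target), nchw.reshape(target))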
Example #16
def binary_op(op_type: str, xgraph: XGraph, node: Node,
              quant_config: NndctQuantInfo):
    input, other = node.node_attr(node.op.AttrName.INPUT), node.node_attr(
        node.op.AttrName.OTHER)
    if isinstance(input, Tensor) and (not isinstance(other, Tensor)):
        operand1 = xgraph.get_op_by_name(input.node.name)
        dtype = other.dtype if isinstance(other, np.ndarray) else type(other)
        operand2 = np.ones(input.shape, dtype=dtype) * other
        operand2 = xgraph.create_const_op(f"{node.name}_other", operand2)
    else:
        operand1 = xgraph.get_op_by_name(input.node.name)
        operand2 = xgraph.get_op_by_name(other.node.name)

    input_ops: Dict[str, List["xir.Op"]] = {}
    input_ops["input"] = [operand1, operand2]
    input_ops["input"] = xgraph.create_input_fix_ops(input_ops["input"],
                                                     node.name, quant_config)
    xgraph.create_fixed_normal_op(node.name,
                                  op_type,
                                  quant_config,
                                  input_ops=input_ops)
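When other is a plain scalar rather than a Tensor, the code above materializes it as a constant with the same shape as input before wiring up the binary op. A tiny sketch of that broadcast (shape and value are illustrative only):

import numpy as np

other = 0.5
input_shape = (1, 16, 8, 8)
operand2 = np.ones(input_shape, dtype=type(other)) * other
assert operand2.shape == input_shape
assert float(operand2[0, 0, 0, 0]) == 0.5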
Example #17
    def __call__(self, parser, raw_node, node_scope=''):
        nndct_node = Node(name=get_full_name(node_scope, raw_node.name),
                          dtype=convert_dtype(raw_node.dtype),
                          idx=raw_node.idx)

        nndct_node.raw_kind = raw_node.kind
        nndct_node.schema = raw_node.schema
        nndct_node.is_custom_extension = raw_node.is_custom_pyop
        nndct_node.caller = raw_node.pyobj

        blob_tensor_convertor = TensorConvertor()
        for op in raw_node.outputs:
            nndct_tensor = blob_tensor_convertor(node_scope, op)
            nndct_tensor.node = nndct_node
            nndct_node.out_tensors.append(nndct_tensor)
            parser.visited_blob_tensors[op.name] = nndct_tensor

        for ip in raw_node.flatten_inputs:
            if ip.name in parser.visited_blob_tensors:
                nndct_node.in_tensors.append(
                    parser.visited_blob_tensors[ip.name])
            elif ip.name in parser.visited_param_tensors:
                parser.node_params[nndct_node].append(
                    parser.visited_param_tensors[ip.name])
                nndct_node.in_tensors.append(
                    parser.visited_param_tensors[ip.name])

        if not raw_node.inputs:
            parser.node_input_args[nndct_node].append(
                nndct_node.out_tensors[0])
        else:
            parser.node_input_args[nndct_node].extend(
                [parser.get_nndct_value(i) for i in raw_node.inputs])

        return nndct_node
Example #18
 def _get_init_param_str(self, node: Node) -> List[str]:
     str_list = []
     if node.has_bound_params():
         return str_list
     for param_type, param_tensor in node.op.params.items():
         if param_tensor.name not in self._tensor_output_map:
             param_name = param_type.value
             param_shape = tuple(param_tensor.shape)
             param_init_str = f"self.{param_name} = torch.nn.parameter.Parameter(torch.Tensor{param_shape})"
             str_list.append(param_init_str)
             self._tensor_output_map[
                 param_tensor.name] = f"self.{param_name}"
     return str_list
Example #19
def reshape(xgraph: XGraph, node: Node,
            quant_config: NndctQuantInfo) -> NoReturn:
    r""" nndct reshape is a macro operator, including pack, reshape
      """
    shape = node.node_attr(node.op.AttrName.SHAPE)
    sub_op_pack, pack_list = _pack(xgraph, node, "shape", shape, quant_config)
    input_ops: Dict[str, List["xir.Op"]] = {}
    input_ops["shape"] = [sub_op_pack]
    input_ops["input"] = [xgraph.get_op_by_name(node.in_nodes[0])]
    xgraph.create_fixed_normal_op(node.name,
                                  "reshape",
                                  quant_config,
                                  input_ops=input_ops)
Example #20
def reduction_mean(xgraph: XGraph, node: Node,
                   quant_config: NndctQuantInfo) -> NoReturn:

    attrs = _get_xir_attr_from_node(node)

    input_ops: Dict[str, List[Op]] = {}
    input_ops["input"] = [xgraph.get_op_by_name(node.in_nodes[0])]
    input_ops["input"] = xgraph.create_input_fix_ops(input_ops["input"],
                                                     node.name, quant_config)
    if len(node.node_attr(node.op.AttrName.DIMS)) == 1:
        xgraph.create_fixed_normal_op(node.name + "_i0",
                                      "reduction_mean",
                                      quant_config,
                                      attrs=attrs,
                                      input_ops=input_ops)
        scale = calculate_op_scale(
            node.in_tensors[0].shape[node.node_attr(node.op.AttrName.DIMS)[0]],
            node)
        scale = [scale]
        xgraph.create_fixed_const_op(name=node.name + "_i1",
                                     data=np.array(scale, dtype=np.float32),
                                     quant_info=quant_config)

        input_ops: Dict[str, List[Op]] = {}
        input_ops["input"] = [
            xgraph.get_op_by_name(node.name + "_i0"),
            xgraph.get_op_by_name(node.name + "_i1")
        ]
        xgraph.create_fixed_normal_op(node.name,
                                      "mul",
                                      quant_config,
                                      input_ops=input_ops)
    else:
        xgraph.create_fixed_normal_op(node.name,
                                      "reduction_mean",
                                      quant_config,
                                      attrs=attrs,
                                      input_ops=input_ops)
Example #21
def const_xop(xgraph: XGraph, node: Node,
              quant_config: NndctQuantInfo) -> NoReturn:
    data = node.node_attr(node.op.AttrName.DATA)
    data_type = np.dtype(node.out_tensors[0].dtype)
    data_type = np.float32 if data_type == np.float64 else data_type
    if not isinstance(data, list):
        data = [data]

    data = np.array(data, dtype=data_type)
    data = np.transpose(
        data, node.transpose_out_order) if node.transpose_out_order else data
    xgraph.create_fixed_const_op(name=node.name,
                                 data=data,
                                 quant_info=quant_config)
Example #22
  def __call__(self, parser, raw_graph, raw_node):
    nndct_node = Node(
        name=get_full_name(raw_graph.name, raw_node.name),
        dtype=convert_dtype(raw_node.dtype),
        idx=raw_node.idx)

    blob_tensor_convertor = TensorConvertor()
    for op in raw_node.outputs:
      nndct_tensor = blob_tensor_convertor(raw_graph.name, op)
      nndct_tensor.node = nndct_node
      nndct_node.out_tensors.append(nndct_tensor)

    for ip in raw_node.flatten_inputs:
      if ip.name not in raw_graph.param_names(
      ) and parser.get_blob_tensor_by_name(ip.name):
        nndct_node.in_tensors.append(parser.get_blob_tensor_by_name(ip.name))
    return nndct_node
Example #23
    def _init_op_and_attrs_str(self, node: Node) -> str:
        torch_op_type = py_utils.get_torch_op_type(node.op.type)
        torch_op_attr = py_utils.get_torch_op_attr(torch_op_type)
        op_name = py_utils.get_defined_quant_module(
            torch_op_type
        ) if not node.has_custom_op() or torch_op_attr.op_class_type in [
            TorchOpClassType.PRIMITIVE
        ] else ''
        op_name, is_defined_op = (op_name, True) if op_name else (".".join(
            [TorchSymbol.MODULE_PREFIX, "Module"]), False)
        attrs_str = ""
        if torch_op_attr.op_class_type == TorchOpClassType.NN_MODULE:
            attrs_str = self._to_map_str(
                self._get_module_attrs_map(node, torch_op_type,
                                           torch_op_attr.attrs))

        if not is_defined_op:
            attrs_str = f"'{node.op.type}',{attrs_str}" if attrs_str else f"'{node.op.type}'"

        return op_name, attrs_str
Example #24
    def layout_tranform(self):
        """layout_transform TORCH(NCHW) -> XIR(NHWC)"""

        custom2xir = GLOBAL_MAP.get_ele(NNDCT_KEYS.CUSTOM_TO_XIR_LIST)
        if custom2xir is None:
            custom2xir = []

        def _find_swim_order(ndim):
            return {
                2: [0, 1],
                3: [0, 2, 1],
                4: [0, 2, 3, 1],
                5: [0, 3, 4, 2, 1]
            }[ndim]

        def _find_sink_order(ndim):
            return {
                2: [0, 1],
                3: [0, 2, 1],
                4: [0, 3, 1, 2],
                5: [0, 4, 3, 1, 2]
            }[ndim]

        def _is_dim_transparent(node):
            return node.in_tensors[0].ndim and node.out_tensors[
                0].ndim and node.in_tensors[0].ndim == node.out_tensors[0].ndim

        def _is_shape_transparent(node):
            return node.in_tensors[0].shape and node.out_tensors[
                0].shape and node.in_tensors[0].shape == node.out_tensors[
                    0].shape

        def _have_special_layout(node):
            return node.out_tensors[0].ndim and node.out_tensors[0].ndim >= 3

        def _is_custom_op(node):
            return isinstance(
                node.op, base_op.CustomOp) and node.op.type not in custom2xir

        def _is_permute_op(node):
            return isinstance(node.op, base_op.Permute)

        def _is_terminate_op(node):
            return node.op.type == NNDCT_OP.RETURN

        implicit_ops = [
            NNDCT_OP.CONV2D, NNDCT_OP.DEPTHWISE_CONV2D,
            NNDCT_OP.DEPTHWISE_CONVTRANSPOSE2D, NNDCT_OP.CONVTRANSPOSE2D,
            NNDCT_OP.MAX_POOL, NNDCT_OP.AVG_POOL, NNDCT_OP.ADAPTIVEAVGPOOL2D,
            NNDCT_OP.INTERPOLATE, NNDCT_OP.UP_SAMPLING, NNDCT_OP.RESIZE,
            NNDCT_OP.BATCH_NORM, NNDCT_OP.MAX_POOL1D, NNDCT_OP.CONV1D,
            NNDCT_OP.CONV3D, NNDCT_OP.DEPTHWISE_CONV3D,
            NNDCT_OP.DEPTHWISE_CONVTRANSPOSE3D, NNDCT_OP.CONVTRANSPOSE3D,
            NNDCT_OP.PIXEL_SHUFFLE, NNDCT_OP.PIXEL_UNSHUFFLE,
            NNDCT_OP.RESIZE_3D, NNDCT_OP.RESIZE_NEAREST_3D, NNDCT_OP.REORG,
            NNDCT_OP.CORRELATION1D_ELEMWISE, NNDCT_OP.CORRELATION2D_ELEMWISE,
            NNDCT_OP.COST_VOLUME
        ]

        special_ops_fn = {
            NNDCT_OP.RESHAPE: shape_attr_transform_fn,
            NNDCT_OP.CONCAT: axis_attr_transform_fn,
            NNDCT_OP.STRIDED_SLICE: slice_attr_transform_fn,
            NNDCT_OP.SUM: reduce_op_attr_transform_fn,
            NNDCT_OP.MAX: reduce_op_attr_transform_fn,
            NNDCT_OP.MEAN: reduce_op_attr_transform_fn,
            NNDCT_OP.SHAPE: axis_attr_transform_fn,
            NNDCT_OP.SOFTMAX: axis_attr_transform_fn,
            NNDCT_OP.ZEROS: shape_attr_transform_fn,
        }

        # collect insert point for transpose
        insert_pos = []
        for node in self._dev_graph.nodes:
            if node.op.type in implicit_ops:
                insert_pos.append(node)

        swim_transpose = defaultdict(list)
        swim_in_transpose = defaultdict(list)
        sink_transpose = defaultdict(list)

        for node in insert_pos:
            transpose_out_order = tuple(
                _find_swim_order(node.out_tensors[0].ndim))
            swim_transpose[transpose_out_order].append(node)
            transpose_in_order = tuple(
                _find_swim_order(node.in_tensors[0].ndim))
            swim_in_transpose[node] = transpose_in_order
            transpose_out_order = tuple(
                _find_sink_order(node.out_tensors[0].ndim))
            sink_transpose[transpose_out_order].append(node)

        nodes_need_to_remove = []
        transpose_insert_between_swim = defaultdict(list)
        visited = []
        # swim_transpose_order, nodes = next(iter(swim_transpose.items()))
        for swim_transpose_order, nodes in swim_transpose.items():
            for insert_node in nodes:
                q = deque()
                q.append(insert_node)
                visited.append(insert_node)
                insert_node.transpose_out_order = swim_transpose_order
                insert_node.transpose_in_order = swim_in_transpose[insert_node]
                while len(q) > 0:
                    node = q.popleft()
                    for pn in self._dev_graph.parents(node):
                        if pn not in visited:

                            if not _have_special_layout(
                                    pn) or pn.op.type in implicit_ops:
                                continue

                            elif pn.op.type in [
                                    NNDCT_OP.INPUT, NNDCT_OP.QUANT_STUB,
                                    NNDCT_OP.CONST, NNDCT_OP.ZEROS
                            ] or _is_dim_transparent(pn) and (
                                    not _is_permute_op(pn)) and (
                                        not _is_custom_op(pn)):
                                pn.transpose_out_order = node.transpose_in_order
                                pn.transpose_in_order = pn.transpose_out_order
                                if pn.op.type in special_ops_fn:
                                    special_ops_fn[pn.op.type](
                                        pn, pn.transpose_out_order)
                                q.append(pn)
                                visited.append(pn)

                            else:
                                # pn.transpose_out_order = [0, 2, 3, 1]
                                transpose_insert_between_swim[
                                    swim_transpose_order].append((pn, node))

        index = 0
        for transpose_order, node_pairs in transpose_insert_between_swim.items(
        ):
            for pn, cn in node_pairs:
                node_name = "_".join([pn.name, "swim_transpose", f"{index}"])
                op = base_op.Permute(NNDCT_OP.PERMUTE)
                new_node = Node(node_name,
                                op=op,
                                dtype=pn.dtype,
                                in_quant_part=pn.in_quant_part)
                new_node.set_node_attr(new_node.op.AttrName.ORDER,
                                       list(transpose_order))
                self._dev_graph.insert_node_between_nodes(new_node, pn, cn)
                nodes_need_to_remove.append(new_node)
                index += 1

        if transpose_insert_between_swim:
            self._dev_graph.reconnect_nodes()

        # debug
        # print("#####swim######")
        # for node in self._dev_graph.nodes:
        #   print(node.op.type, node.name, node.transpose_out_order)

        transpose_insert_between_sink = defaultdict(list)
        visited = []
        for node in self._dev_graph.nodes:
            if node.transpose_out_order:
                nodes = sink_transpose[tuple(
                    _find_sink_order(len(node.transpose_out_order)))]
                if node not in nodes:
                    nodes.append(node)

        for sink_transpose_order, nodes in sink_transpose.items():
            for insert_node in nodes:
                if insert_node not in visited:
                    q = deque()
                    q.append(insert_node)
                    visited.append(insert_node)
                    while len(q) > 0:
                        node = q.popleft()
                        for cn in self._dev_graph.children(node):
                            if cn not in visited:
                                if cn.op.type in implicit_ops or _is_terminate_op(
                                        cn):
                                    continue
                                elif cn.op.type == NNDCT_OP.SHAPE:
                                    visited.append(cn)
                                    if node.transpose_out_order:
                                        special_ops_fn[cn.op.type](
                                            cn, node.transpose_out_order)
                                        continue
                                elif cn.transpose_out_order:
                                    q.append(cn)
                                    visited.append(cn)
                                elif _is_dim_transparent(cn) and (
                                        not _is_permute_op(cn)) and (
                                            not _is_custom_op(cn)):
                                    cn.transpose_in_order = node.transpose_out_order
                                    cn.transpose_out_order = cn.transpose_in_order
                                    q.append(cn)
                                    visited.append(cn)
                                    if cn.op.type in special_ops_fn:
                                        special_ops_fn[cn.op.type](
                                            cn, cn.transpose_out_order)
                                else:
                                    transpose_insert_between_sink[
                                        sink_transpose_order].append(
                                            (node, cn))

        index = 0
        for transpose_order, node_pairs in transpose_insert_between_sink.items(
        ):
            for pn, cn in node_pairs:

                node_name = "_".join([pn.name, "sink_transpose", f"{index}"])
                op = base_op.Permute(NNDCT_OP.PERMUTE)
                new_node = Node(node_name,
                                op=op,
                                dtype=pn.dtype,
                                in_quant_part=cn.in_quant_part)
                new_node.set_node_attr(new_node.op.AttrName.ORDER,
                                       list(transpose_order))
                self._dev_graph.insert_node_between_nodes(new_node, pn, cn)

                nodes_need_to_remove.append(new_node)
                index += 1

        if transpose_insert_between_sink:
            self._dev_graph.reconnect_nodes()

        # debug
        # print("#####sink######")
        # for node in self._dev_graph.nodes:
        #   print(node.op.type, node.name, node.transpose_out_order)
        neighbor_broadcast = {}
        for node in self._dev_graph.nodes:
            if len(node.in_nodes) <= 1 or node.op.type in implicit_ops:
                continue
            if all([
                    pn.transpose_out_order is None
                    for pn in self._dev_graph.parents(node)
            ]) or all([
                    pn.transpose_out_order is not None
                    for pn in self._dev_graph.parents(node)
            ]):
                continue
            #if node.out_tensors[0].dtype != "float32":
            #  continue
            transpose_order = None
            for pn in self._dev_graph.parents(node):
                transpose_order = pn.transpose_out_order
                if transpose_order is not None:
                    break

            neighbor_broadcast[node] = transpose_order

        have_neighbors = False
        for node, transpose_order in neighbor_broadcast.items():
            index = 0
            for pn in self._dev_graph.parents(node):
                if pn.transpose_out_order is None and pn.out_tensors[
                        0].ndim and node.out_tensors[0].ndim and pn.out_tensors[
                            0].ndim == node.out_tensors[0].ndim:
                    # pn.transpose_out_order = node.transpose_out_order
                    node_name = "_".join(
                        [node.name, "neighbor_transpose", f"{index}"])
                    op = base_op.Permute(NNDCT_OP.PERMUTE)
                    new_node = Node(node_name,
                                    op=op,
                                    dtype=node.dtype,
                                    in_quant_part=pn.in_quant_part)
                    new_node.set_node_attr(new_node.op.AttrName.ORDER,
                                           list(transpose_order))
                    self._dev_graph.insert_node_between_nodes(
                        new_node, pn, node)

                    index += 1

                    nodes_need_to_remove.append(new_node)
                    have_neighbors = True

        if have_neighbors:
            self._dev_graph.reconnect_nodes()

        # Debug
        # print("####neightbor######")
        # for node in self._dev_graph.nodes:
        #   print(node.op.type, node.name, node.transpose_out_order)
        # remove consecutive transpose

        def merge_father_and_child(node, visited, transpose_group,
                                   reserverd_nodes):
            visited.append(node)
            if _is_permute_op(node):
                if node.out_nodes and all([
                        _is_permute_op(cn)
                        for cn in self._dev_graph.children(node)
                ]):
                    transpose_group.append(node)
                else:
                    transpose_group.append(node)

                    order = []
                    reserved_trans = None
                    for trans in transpose_group:
                        if trans not in nodes_need_to_remove:
                            reserved_trans = trans

                        if not order:
                            order = trans.node_attr(trans.op.AttrName.ORDER)
                        else:
                            new_order = len(order) * [None]
                            tmp_order = trans.node_attr(
                                trans.op.AttrName.ORDER)
                            for i in range(len(order)):
                                t_i = tmp_order[i]
                                new_order[i] = order[t_i]
                            order = new_order

                    if reserved_trans is None:
                        reserved_trans = transpose_group[-1]

                    reserved_trans.set_node_attr(
                        reserved_trans.op.AttrName.ORDER, order)
                    reserverd_nodes.append(reserved_trans)

                    transpose_group.clear()

            for cn in self._dev_graph.children(node):
                if cn not in visited:
                    merge_father_and_child(cn, visited, transpose_group,
                                           reserverd_nodes)

        def merge_brothers(reserverd_nodes):
            remove_nodes = []
            for node in self._dev_graph.nodes:
                if len(node.out_nodes) > 1 and all([
                        _is_permute_op(cn)
                        for cn in self._dev_graph.children(node)
                ]):
                    need_merge = True
                    order = None
                    for trans_node in self._dev_graph.children(node):
                        if order is not None:
                            if order != trans_node.node_attr(
                                    trans_node.op.AttrName.ORDER):
                                need_merge = False
                                break
                        else:
                            order = trans_node.node_attr(
                                trans_node.op.AttrName.ORDER)

                    if need_merge:
                        reserverd_node = None
                        for trans_node in self._dev_graph.children(node):
                            if trans_node not in nodes_need_to_remove:
                                reserverd_node = trans_node

                        if reserverd_node is None:
                            reserverd_node = self._dev_graph.children(node)[0]

                        for trans_node in self._dev_graph.children(node):
                            if trans_node is not reserverd_node and trans_node in reserverd_nodes:
                                remove_nodes.append(trans_node)

                                out_tensor = trans_node.out_tensors[0]
                                out_tensor.replace_uses_with(
                                    reserverd_node.out_tensors[0])

            for node in remove_nodes:
                node.destroy()

            if remove_nodes:
                self._dev_graph.reconnect_nodes()

        source_nodes = []
        for node in self._dev_graph.nodes:
            if not node.in_tensors:
                source_nodes.append(node)

        transpose_group = []
        reserverd_nodes = []
        visited = []
        for source in source_nodes:
            merge_father_and_child(source, visited, transpose_group,
                                   reserverd_nodes)

        nodes_need_to_remove = [
            node for node in nodes_need_to_remove
            if node not in reserverd_nodes
        ]

        for node in reserverd_nodes:
            order = node.node_attr(node.op.AttrName.ORDER)
            keep_order = True
            if any([index != dim for index, dim in enumerate(order)]):
                keep_order = False
            if keep_order:
                nodes_need_to_remove.append(node)

        for node in nodes_need_to_remove:
            self._dev_graph.remove_node(node)

        merge_brothers(reserverd_nodes)

        # debug
        # print("#####finalize######")
        # for node in self._dev_graph.nodes:
        #   print(node.op.type, node.name, node.transpose_out_order)

        def delete_transpose_of_correlation(self):
            nodes_need_to_delete_for_special_ops = []
            nodes_need_to_insert_aster_special_ops = []
            nodes_need_to_merge_for_special_ops = []
            for node in self._dev_graph.nodes:
                if node.op.type == NNDCT_OP.MEAN and not node.node_attr(
                        node.op.AttrName.KEEP_DIMS
                ) and self._dev_graph.parents(node):
                    pn = self._dev_graph.parents(node)[0]
                    if pn.in_tensors and _is_permute_op(
                            pn) and self._dev_graph.parents(pn):
                        gpn = self._dev_graph.parents(pn)[0]
                        if gpn.op.type in [
                                NNDCT_OP.CORRELATION1D_ELEMWISE,
                                NNDCT_OP.CORRELATION2D_ELEMWISE
                        ] and node.out_tensors[0].ndim and gpn.out_tensors[
                                0].ndim == 5 and node.out_tensors[0].ndim == 4:

                            nodes_need_to_delete_for_special_ops.append(pn)

                            node.transpose_in_order = tuple(
                                _find_swim_order(5))
                            node.transpose_out_order = tuple(
                                _find_swim_order(4))
                            special_ops_fn[node.op.type](
                                node, node.transpose_in_order)

                            nodes_need_to_insert_aster_special_ops.append(node)
            index = 0
            for node in nodes_need_to_insert_aster_special_ops:
                cn = self._dev_graph.children(node)[0]
                node_name = "_".join([node.name, "sink_transpose", f"{index}"])
                op = base_op.Permute(NNDCT_OP.PERMUTE)
                new_node = Node(node_name,
                                op=op,
                                dtype=node.dtype,
                                in_quant_part=node.in_quant_part)
                new_node.set_node_attr(new_node.op.AttrName.ORDER,
                                       tuple(_find_sink_order(4)))
                self._dev_graph.insert_node_between_nodes(new_node, node, cn)
                nodes_need_to_merge_for_special_ops.append(new_node)
                index += 1

            for node in nodes_need_to_delete_for_special_ops:
                self._dev_graph.remove_node(node)

            source_nodes = []
            for node in self._dev_graph.nodes:
                if not node.in_tensors:
                    source_nodes.append(node)

            transpose_group = []
            reserverd_nodes = []
            visited = []
            for source in nodes_need_to_merge_for_special_ops:
                merge_father_and_child(source, visited, transpose_group,
                                       reserverd_nodes)

            nodes_need_to_merge_for_special_ops = [
                node for node in nodes_need_to_merge_for_special_ops
                if node not in reserverd_nodes
            ]

            for node in reserverd_nodes:
                order = node.node_attr(node.op.AttrName.ORDER)
                keep_order = True
                if any([index != dim for index, dim in enumerate(order)]):
                    keep_order = False
                if keep_order:
                    nodes_need_to_merge_for_special_ops.append(node)

            for node in nodes_need_to_merge_for_special_ops:
                self._dev_graph.remove_node(node)

            merge_brothers(reserverd_nodes)

        delete_transpose_of_correlation(self)
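The swim and sink order tables used throughout layout_tranform are inverse permutations of each other, so a transpose inserted with a swim order is exactly undone by one inserted with the matching sink order. A quick numpy sanity check of that property (tables copied from _find_swim_order / _find_sink_order above):

import numpy as np

swim = {2: [0, 1], 3: [0, 2, 1], 4: [0, 2, 3, 1], 5: [0, 3, 4, 2, 1]}
sink = {2: [0, 1], 3: [0, 2, 1], 4: [0, 3, 1, 2], 5: [0, 4, 3, 1, 2]}
for ndim in (2, 3, 4, 5):
    x = np.random.rand(*range(2, 2 + ndim))  # distinct size per axis
    assert np.array_equal(x.transpose(swim[ndim]).transpose(sink[ndim]), x)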
Example #25
    def _convert_node(self, raw_node, scope=None):
        if scope is None:
            assert self.cur_graph
            node_scope = self.cur_graph.name
        else:
            node_scope = scope

        nndct_node = Node(
            name=get_full_name(node_scope, raw_node.name),
            dtype=self.convert_dtype(raw_node.dtype),
        )
        nndct_node.source_range = raw_node.source_range
        nndct_node.scope_name = raw_node.scope_name
        if nndct_node.name in self.cur_graph:
            return self.cur_graph.node(nndct_node.name)

        # nndct_node.raw_kind = raw_node.kind
        # self.converted_node.add(raw_node)
        nndct_node.schema = raw_node.schema
        nndct_node.is_custom_extension = raw_node.is_custom_pyop
        nndct_node.caller = raw_node.pyobj
        nndct_node.owning_block = self.cur_block
        nndct_node.owning_graph = self.cur_graph
        for out in raw_node.outputs:
            full_name = get_full_name(node_scope, out.name)
            if self.cur_graph and self.cur_graph.is_tensor_in_graph(full_name):
                nndct_node.add_out_tensor(self.cur_graph.tensor(full_name))
            else:
                nndct_tensor = self._convert_tensor(out, node_scope)
                nndct_node.add_out_tensor(nndct_tensor)

        for ip in raw_node.flatten_inputs:
            full_name = get_full_name(node_scope, ip.name)
            if self.cur_graph and self.cur_graph.is_tensor_in_graph(full_name):
                nndct_node.add_in_tensor(self.cur_graph.tensor(full_name))
            elif not raw_node.outputs:
                # For Return node
                nndct_tensor = self._convert_tensor(ip, node_scope)
                nndct_node.add_in_tensor(nndct_tensor)

            if self.cur_graph and full_name in self.cur_graph.param_names():
                self.node_params[nndct_node].append(
                    self.cur_graph.tensor(full_name))

        node_input_args = []
        if not raw_node.inputs:
            node_input_args.extend(
                [self.get_nndct_value(i) for i in raw_node.outputs])
        else:
            node_input_args.extend(
                [self.get_nndct_value(i) for i in raw_node.inputs])

        nndct_node.op = self._create_op(raw_node.kind, nndct_node,
                                        node_input_args)

        return nndct_node
Example #26
        def delete_transpose_of_correlation(self):
            nodes_need_to_delete_for_special_ops = []
            nodes_need_to_insert_aster_special_ops = []
            nodes_need_to_merge_for_special_ops = []
            for node in self._dev_graph.nodes:
                if node.op.type == NNDCT_OP.MEAN and not node.node_attr(
                        node.op.AttrName.KEEP_DIMS
                ) and self._dev_graph.parents(node):
                    pn = self._dev_graph.parents(node)[0]
                    if pn.in_tensors and _is_permute_op(
                            pn) and self._dev_graph.parents(pn):
                        gpn = self._dev_graph.parents(pn)[0]
                        if gpn.op.type in [
                                NNDCT_OP.CORRELATION1D_ELEMWISE,
                                NNDCT_OP.CORRELATION2D_ELEMWISE
                        ] and node.out_tensors[0].ndim and gpn.out_tensors[
                                0].ndim == 5 and node.out_tensors[0].ndim == 4:

                            nodes_need_to_delete_for_special_ops.append(pn)

                            node.transpose_in_order = tuple(
                                _find_swim_order(5))
                            node.transpose_out_order = tuple(
                                _find_swim_order(4))
                            special_ops_fn[node.op.type](
                                node, node.transpose_in_order)

                            nodes_need_to_insert_aster_special_ops.append(node)
            index = 0
            for node in nodes_need_to_insert_aster_special_ops:
                cn = self._dev_graph.children(node)[0]
                node_name = "_".join([node.name, "sink_transpose", f"{index}"])
                op = base_op.Permute(NNDCT_OP.PERMUTE)
                new_node = Node(node_name,
                                op=op,
                                dtype=node.dtype,
                                in_quant_part=node.in_quant_part)
                new_node.set_node_attr(new_node.op.AttrName.ORDER,
                                       tuple(_find_sink_order(4)))
                self._dev_graph.insert_node_between_nodes(new_node, node, cn)
                nodes_need_to_merge_for_special_ops.append(new_node)
                index += 1

            for node in nodes_need_to_delete_for_special_ops:
                self._dev_graph.remove_node(node)

            source_nodes = []
            for node in self._dev_graph.nodes:
                if not node.in_tensors:
                    source_nodes.append(node)

            transpose_group = []
            reserverd_nodes = []
            visited = []
            for source in nodes_need_to_merge_for_special_ops:
                merge_father_and_child(source, visited, transpose_group,
                                       reserverd_nodes)

            nodes_need_to_merge_for_special_ops = [
                node for node in nodes_need_to_merge_for_special_ops
                if node not in reserverd_nodes
            ]

            for node in reserverd_nodes:
                order = node.node_attr(node.op.AttrName.ORDER)
                keep_order = True
                if any([index != dim for index, dim in enumerate(order)]):
                    keep_order = False
                if keep_order:
                    nodes_need_to_merge_for_special_ops.append(node)

            for node in nodes_need_to_merge_for_special_ops:
                self._dev_graph.remove_node(node)

            merge_brothers(reserverd_nodes)
Example #27
def sub(xgraph: XGraph, node: Node, quant_config: NndctQuantInfo) -> NoReturn:
    operand1, operand2 = node.node_attr(
        node.op.AttrName.INPUT), node.node_attr(node.op.AttrName.OTHER)
    _sub(xgraph, node.name, operand1, operand2, quant_config)