def squeeze(xgraph: XGraph, node: Node, quant_config: NndctQuantInfo) -> NoReturn:
  if node.in_tensors[0].ndim == 4 and len(
      node.node_attr(node.op.AttrName.DIMS)) == 1:
    attrs: Dict[str, Any] = {}
    attrs["order"] = [0, 3, 1, 2]  # resume dimension to NCHW
    input_ops: Dict[str, List[Op]] = {}
    input_list = []
    for input in node.in_nodes:
      input_op = xgraph.get_op_by_name(input)
      input_list.append(input_op)
    input_ops["input"] = input_list
    xgraph.create_fixed_normal_op(
        node.name + "_i0", "transpose", quant_config, attrs=attrs,
        input_ops=input_ops)

    attrs: Dict[str, Any] = {}
    dim = node.node_attr(node.op.AttrName.DIMS)[0]
    dim = transformed_axis("NHWC", "NCHW", ndim=4, dim=dim)
    attrs["axis"] = [dim]
    input_ops: Dict[str, List[Op]] = {}
    input_ops["input"] = [xgraph.get_op_by_name(node.name + "_i0")]
    xgraph.create_fixed_normal_op(
        node.name, "squeeze", quant_config, attrs=attrs, input_ops=input_ops)
  else:
    to_xir("squeeze")(xgraph, node, quant_config)
def default_xop(xop_type: str, xgraph: XGraph, node: Node,
                quant_config: NndctQuantInfo) -> NoReturn:
  input_ops: Dict[str, List["xir.Op"]] = {}
  if node.has_bound_params():
    for param_name, param_tensor in node.op.params.items():
      param = xgraph.get_op_by_name(param_tensor.name)
      input_ops[param_name.name.lower()] = [param]

  input_list = []
  for input in node.in_tensors:
    if node.has_bound_params() and input.is_param_tensor():
      continue
    elif input.is_param_tensor():
      input_op = xgraph.get_op_by_name(input.name)
    else:
      input_op = xgraph.get_op_by_name(input.node.name)
    input_list.append(input_op)

  input_ops["input"] = xgraph.create_input_fix_ops(input_list, node.name,
                                                   quant_config)
  attrs = _get_xir_attr_from_node(node)
  xgraph.create_fixed_normal_op(
      node.name, xop_type, quant_config, attrs=attrs, input_ops=input_ops)
def conv_transpose_3d(xgraph: XGraph, node: Node,
                      quant_config: NndctQuantInfo) -> NoReturn:
  attrs = _get_xir_attr_from_node(node)
  attrs['kernel'] = attrs['kernel'][::-1]
  attrs['stride'] = attrs['stride'][::-1]
  attrs['dilation'] = attrs['dilation'][::-1]
  attrs['pad'] = list(
      itertools.chain.from_iterable([[pad] * 2 for pad in attrs['pad'][::-1]]))
  print(attrs)
  input_ops: Dict[str, List["xir.Op"]] = {}
  if node.has_bound_params():
    for param_name, param_tensor in node.op.params.items():
      param = xgraph.get_op_by_name(param_tensor.name)
      input_ops[param_name.name.lower()] = [param]

  input_list = []
  for input in node.in_tensors:
    if node.has_bound_params() and input.is_param_tensor():
      continue
    elif input.is_param_tensor():
      input_op = xgraph.get_op_by_name(input.name)
    else:
      input_op = xgraph.get_op_by_name(input.node.name)
    input_list.append(input_op)

  input_ops["input"] = xgraph.create_input_fix_ops(input_list, node.name,
                                                   quant_config)
  xgraph.create_fixed_normal_op(
      node.name, "transposed_conv3d", quant_config, attrs=attrs,
      input_ops=input_ops)
def permute_invar_op(xop_type, xgraph: XGraph, node: Node,
                     quant_config: NndctQuantInfo) -> NoReturn:
  if not node.node_attr(node.op.AttrName.KEEP_DIMS) \
      and node.in_tensors[0].ndim == 4 \
      and len(node.node_attr(node.op.AttrName.DIMS)) == 1 \
      and node.node_attr(node.op.AttrName.DIMS)[0] != 3:
    layout = ["N", "H", "W", "C"]
    del layout[node.node_attr(node.op.AttrName.DIMS)[0]]
    # create mean which keep_dim is True
    attrs: Dict[str, Any] = {}
    attrs["axis"] = node.node_attr(node.op.AttrName.DIMS)
    attrs["keep_dims"] = True
    input_ops: Dict[str, List[Op]] = {}
    input_list = []
    for input in node.in_nodes:
      input_op = xgraph.get_op_by_name(input)
      input_list.append(input_op)
    input_ops["input"] = xgraph.create_input_fix_ops(input_list, node.name,
                                                     quant_config)
    xgraph.create_fixed_normal_op(
        node.name + "_i0", xop_type, quant_config, attrs=attrs,
        input_ops=input_ops)

    attrs: Dict[str, Any] = {}
    if layout == ["N", "H", "C"]:
      attrs["order"] = [0, 3, 1, 2]
    else:
      attrs["order"] = [0, 3, 2, 1]
    # resume dimension to NCHW
    input_ops: Dict[str, List[Op]] = {}
    input_ops["input"] = [xgraph.get_op_by_name(node.name + "_i0")]
    xgraph.create_fixed_normal_op(
        node.name + "_i1", "transpose", quant_config, attrs=attrs,
        input_ops=input_ops)

    attrs: Dict[str, Any] = {}
    if layout == ["N", "H", "C"]:
      attrs["axis"] = [3]
    else:
      attrs["axis"] = [2]
    input_ops: Dict[str, List[Op]] = {}
    input_ops["input"] = [xgraph.get_op_by_name(node.name + "_i1")]
    xgraph.create_fixed_normal_op(
        node.name, "squeeze", quant_config, attrs=attrs, input_ops=input_ops)
  else:
    to_xir(xop_type)(xgraph, node, quant_config)
def avgpool(xgraph: XGraph, node: Node, quant_config: NndctQuantInfo) -> NoReturn:
  needScale = False
  scale = 1.0
  if node.node_attr(node.op.AttrName.KERNEL) == [3, 3]:
    needScale = True
    scale = 9.0 * 7.0 / 64.0
  elif node.node_attr(node.op.AttrName.KERNEL) == [5, 5]:
    needScale = True
    scale = 25.0 * 10.0 / 256.0
  elif node.node_attr(node.op.AttrName.KERNEL) in [[6, 6], [3, 6], [6, 3]]:
    needScale = True
    scale = 36.0 * 7.0 / 256.0
  elif node.node_attr(node.op.AttrName.KERNEL) == [7, 7]:
    needScale = True
    scale = 49.0 * 21.0 / 1024.0
  elif node.node_attr(node.op.AttrName.KERNEL) == [14, 14]:
    needScale = True
    scale = 196.0 * 21.0 / 4096.0

  if needScale:
    attrs = _get_attr_from_node(node)
    # attrs: Dict[str, Any] = {}
    # for attr_name, attr_value in node.op.attrs.items():
    #   attrs[attr_name.value] = _Converter.to_xir_attr_value(attr_name.value, attr_value.value)
    input_ops: Dict[str, List[Op]] = {}
    input_ops["input"] = [xgraph.get_op_by_name(node.in_nodes[0])]
    input_ops["input"] = xgraph.create_input_fix_ops(input_ops["input"],
                                                     node.name, quant_config)
    xgraph.create_fixed_normal_op(
        node.name + "_i0", "avgpool2d", quant_config, attrs=attrs,
        input_ops=input_ops)

    scale = [scale]
    xgraph.create_fixed_const_op(
        name=node.name + "_i1",
        data=np.array(scale, dtype=np.float32),
        quant_info=quant_config)

    input_ops: Dict[str, List[Op]] = {}
    input_ops["input"] = [
        xgraph.get_op_by_name(node.name + "_i0"),
        xgraph.get_op_by_name(node.name + "_i1")
    ]
    xgraph.create_fixed_normal_op(node.name, "mul", quant_config,
                                  input_ops=input_ops)
  else:
    to_xir("avgpool2d")(xgraph, node, quant_config)
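# Illustrative sketch, not part of the original source. One way to read the
# hard-coded factors in avgpool() above is kernel_area * (m / 2**s), where
# m / 2**s is a small integer-over-power-of-two approximation of
# 1 / kernel_area (e.g. 7/64 ~ 1/9 for a 3x3 kernel); multiplying the exact
# average by that factor then mimics the approximated divide. The (m, 2**s)
# pairs below are inferred from the constants, not documented here.
def _avgpool_scale_table_check():
  approx = {9: (7, 64), 25: (10, 256), 36: (7, 256),
            49: (21, 1024), 196: (21, 4096)}
  for area, (m, den) in approx.items():
    # the two printed values are close for every entry in the table
    print(f"1/{area} = {1.0 / area:.6f}  vs  {m}/{den} = {m / den:.6f}")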
def resize(xgraph: XGraph, node: Node, quant_config: NndctQuantInfo) -> NoReturn:
  r"""resize is a macro operator, including concat, resize"""
  attrs: Dict[str, Any] = {}
  # attrs["scale"] = node.node_attr(node.op.AttrName.SCALE)
  attrs["align_corners"] = node.node_attr(node.op.AttrName.ALIGN_CORNERS)
  attrs["half_pixel_centers"] = node.node_attr(
      node.op.AttrName.HALF_PIXEL_CENTERS)
  attrs["mode"] = node.node_attr(node.op.AttrName.MODE)
  # attrs["mode"] = {0: "NEAREST", 3: "BILINEAR"}.get(attrs["mode"])
  size = node.node_attr(node.op.AttrName.SIZE)
  scale = node.node_attr(node.op.AttrName.SCALE)
  # if size[0] == 0 and size[1] == 0:
  if all([s == 0 for s in size]):
    attrs["scale"] = scale
    input_ops: Dict[str, List["xir.Op"]] = {}
    input_list = []
    for input in node.in_nodes:
      input_op = xgraph.get_op_by_name(input)
      input_list.append(input_op)
    input_ops["input"] = xgraph.create_input_fix_ops(input_list, node.name,
                                                     quant_config)
    xgraph.create_fixed_normal_op(
        node.name, "resize", quant_config, attrs=attrs, input_ops=input_ops)
  else:
    sub_pack_op, pack_list = _pack(xgraph, node, "size", size, quant_config)
    input_ops: Dict[str, List["xir.Op"]] = {}
    input_ops["size"] = [sub_pack_op]
    input_list = []
    for input in node.in_nodes:
      input_op = xgraph.get_op_by_name(input)
      input_list.append(input_op)
    input_ops["input"] = input_list
    input_ops["input"] = [
        op for op in input_ops["input"]
        if op.get_name() not in [i.get_name() for i in pack_list]
    ]
    input_ops["input"] = xgraph.create_input_fix_ops(input_ops["input"],
                                                     node.name, quant_config)
    xgraph.create_fixed_normal_op(
        node.name, "resize", quant_config, attrs=attrs, input_ops=input_ops)
def _generate_scale_node(node, scale_param, node_idx):
  node_name = node.name
  tensor_out = node.out_tensors[0]
  out_tensor_shape = tensor_out.shape
  scale_op = TorchBaseOperation(
      NNDCT_OP.CHANNEL_SCALE, "Channel_Scale", force_to_primitive=True)
  scale_op.set_config("channel_scale", scale_param)
  scale_op.set_config("input", tensor_out)
  if node.op.type == NNDCT_OP.CHANNEL_SCALE:
    scale_node_name = node_name[:node_name.rfind(".")] + ".2"
  else:
    scale_node_name = "/".join([node_name, "channel_scale.1"])
  scale_node = Node(
      scale_node_name,
      op=scale_op,
      dtype="float32",
      idx=node_idx,
      in_quant_part=False)
  out_tensor = Tensor(
      name=f"{scale_node_name}.0",
      node=scale_node,
      shape=out_tensor_shape,
      dtype="float32",
      layout=tensor_out.layout)
  scale_node.out_tensors.append(out_tensor)
  scale_node.in_tensors.append(tensor_out)
  return scale_node
def shape(xgraph: XGraph, node: Node, quant_config: NndctQuantInfo) -> NoReturn:
  r"""nndct shape is a macro operator, including shape, stridedslice"""
  # raise NotImplementedError("shape")
  input_list = []
  shape_input_ops: Dict[str, List[Op]] = {}
  for input in node.in_nodes:
    input_op = xgraph.get_op_by_name(input)
    input_list.append(input_op)
  shape_input_ops["input"] = input_list
  sub_op_shape = xgraph.create_fixed_normal_op(
      node.name + "_i0", "shape", quant_config, input_ops=shape_input_ops)

  attrs: Dict[str, Any] = {}
  strided_slice_input_ops: Dict[str, List[Op]] = {}
  strided_slice_input_ops["input"] = [sub_op_shape]
  dim = node.node_attr(node.op.AttrName.AXIS)
  attrs["begin"] = [dim]
  attrs["end"] = [dim + 1]
  xgraph.create_fixed_normal_op(
      node.name, "strided_slice", quant_config, attrs=attrs,
      input_ops=strided_slice_input_ops)
def __call__(self,
             graph,
             node_name,
             op,
             num_out_tensors,
             shape=None,
             in_tensors=None,
             in_quant_part=True):
  node_name = get_full_name(graph.name, node_name)
  node = Node(
      node_name,
      op=op,
      dtype="float32",
      idx=self._idx,
      in_quant_part=in_quant_part)
  # print(f"{node.name} quant state: {node.in_quant_part}")
  for i in range(num_out_tensors):
    tensor = Tensor(name=f"{node_name}_{i}", node=node, shape=shape)
    node.out_tensors.append(tensor)
  if in_tensors:
    for tensor in in_tensors:
      node.in_tensors.append(tensor)
  graph.add_node(node)
  self._idx += 1
def zeros(xgraph: XGraph, node: Node, quant_config: NndctQuantInfo) -> NoReturn:
  shape = node.node_attr(node.op.AttrName.SHAPE)
  data = np.zeros(
      shape, dtype=_Converter.to_numpy_dtype(node.out_tensors[0].dtype))
  xgraph.create_fixed_const_op(
      name=node.name, data=data, quant_info=quant_config)
def _gen_module_name(self, node: Node):
  module_name = TorchSymbol.MODULE_NAME_SEPERATOR.join(
      [TorchSymbol.MODULE_BASE_SYMBOL, str(self._global_idx)])
  node.idx = self._global_idx
  self._global_idx += 1
  self._module_names[node.name] = module_name
  return module_name
def avgpool(xgraph: XGraph, node: Node, quant_config: NndctQuantInfo) -> NoReturn:
  scale = 1.0
  if node.node_attr(node.op.AttrName.KERNEL) == [3, 3]:
    scale = 9.0 * 7.0 / 64.0
  elif node.node_attr(node.op.AttrName.KERNEL) == [5, 5]:
    scale = 25.0 * 10.0 / 256.0
  elif node.node_attr(node.op.AttrName.KERNEL) in [[6, 6], [3, 6], [6, 3]]:
    scale = 36.0 * 7.0 / 256.0
  elif node.node_attr(node.op.AttrName.KERNEL) == [7, 7]:
    scale = 49.0 * 21.0 / 1024.0
  elif node.node_attr(node.op.AttrName.KERNEL) == [14, 14]:
    scale = 196.0 * 21.0 / 4096.0
  else:
    rec = node.node_attr(node.op.AttrName.KERNEL)[0] * node.node_attr(
        node.op.AttrName.KERNEL)[1]
    max_factor = math.ceil(math.log(rec * 128, 2))
    diff = 1.0
    multi_factor = 0.0
    shift_factor = 0.0
    for shift_factor_ in range(max_factor):
      factor = round((2**shift_factor_) / rec)
      diff_ = abs(factor / (2**shift_factor_) - 1 / rec)
      if diff_ < diff:
        multi_factor = factor
        diff = diff_
        shift_factor = shift_factor_
    scale = rec * multi_factor / (2**shift_factor)

  attrs = _get_xir_attr_from_node(node)
  # attrs: Dict[str, Any] = {}
  # for attr_name, attr_value in node.op.attrs.items():
  #   attrs[attr_name.value] = _Converter.to_xir_attr_value(attr_name.value, attr_value.value)
  input_ops: Dict[str, List["xir.Op"]] = {}
  input_ops["input"] = [xgraph.get_op_by_name(node.in_nodes[0])]
  input_ops["input"] = xgraph.create_input_fix_ops(input_ops["input"],
                                                   node.name, quant_config)
  xgraph.create_fixed_normal_op(
      node.name + "_i0", "avgpool2d", quant_config, attrs=attrs,
      input_ops=input_ops)

  scale = [scale]
  xgraph.create_fixed_const_op(
      name=node.name + "_i1",
      data=np.array(scale, dtype=np.float32),
      quant_info=quant_config)

  input_ops: Dict[str, List["xir.Op"]] = {}
  input_ops["input"] = [
      xgraph.get_op_by_name(node.name + "_i0"),
      xgraph.get_op_by_name(node.name + "_i1")
  ]
  xgraph.create_fixed_normal_op(node.name, "mul", quant_config,
                                input_ops=input_ops)
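# Illustrative sketch, not part of the original source: a standalone version of
# the generic else-branch above. It searches for an integer multiplier and a
# power-of-two shift such that multiplier / 2**shift best approximates
# 1 / (kernel_h * kernel_w), then returns the compensation factor applied after
# the exact average pooling. Kernels covered by the hard-coded table keep their
# table values instead of this search result.
def _avgpool_scale_sketch(kernel_h, kernel_w):
  import math  # the surrounding module already imports math

  rec = kernel_h * kernel_w
  max_factor = math.ceil(math.log(rec * 128, 2))
  best = (0, 0, 1.0)  # (multiplier, shift, approximation error)
  for shift in range(max_factor):
    factor = round((2**shift) / rec)
    err = abs(factor / (2**shift) - 1 / rec)
    if err < best[2]:
      best = (factor, shift, err)
  multiplier, shift, _ = best
  return rec * multiplier / (2**shift)

# e.g. _avgpool_scale_sketch(4, 4) == 1.0, since 1/16 is exactly 1/2**4 and
# needs no compensation.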
def const_xop(xgraph: XGraph, node: Node, quant_config: NndctQuantInfo) -> NoReturn:
  data = node.node_attr(node.op.AttrName.DATA)
  data_type = np.dtype(node.out_tensors[0].dtype)
  if not isinstance(data, list):
    data = [data]
  xgraph.create_fixed_const_op(
      name=node.name, data=np.array(data, dtype=data_type),
      quant_info=quant_config)
def _get_module_attrs_map(self, node: Node, torch_op_type: str,
                          torch_op_attrs: Dict[str, Any]):
  ordered_attrs = OrderedDict()
  attrs_template = torch_op_attrs if torch_op_attrs and torch_op_attrs != [
      'args', 'kwargs'
  ] else node.op.configs
  for name in attrs_template:
    if hasattr(node.op, name):
      ordered_attrs[name] = node.node_config(name)
  return ordered_attrs
def reshape(xgraph: XGraph, node: Node, quant_config: NndctQuantInfo) -> NoReturn:
  r"""nndct reshape is a macro operator, including pack, reshape"""
  # raise NotImplementedError("reshape")
  if node.in_tensors[0].ndim != 4 or node.in_tensors[
      0].layout == Tensor.Layout.NHWC:
    shape = node.node_attr(node.op.AttrName.SHAPE)
    sub_op_pack, pack_list = _pack(xgraph, node, "shape", shape, quant_config)
    input_ops: Dict[str, List[Op]] = {}
    input_ops["shape"] = [sub_op_pack]
    input_ops["input"] = [xgraph.get_op_by_name(node.in_nodes[0])]
    xgraph.create_fixed_normal_op(node.name, "reshape", quant_config,
                                  input_ops=input_ops)
  else:
    shape = node.node_attr(node.op.AttrName.SHAPE)
    sub_op_pack, pack_list = _pack(xgraph, node, "shape", shape, quant_config)
    attrs: Dict[str, Any] = {}
    # NHWC -> NCHW
    attrs["order"] = [0, 3, 1, 2]
    input_ops: Dict[str, List[Op]] = {}
    input_ops["input"] = [xgraph.get_op_by_name(node.in_nodes[0])]
    xgraph.create_fixed_normal_op(
        node.name + "_i0", "transpose", quant_config, attrs=attrs,
        input_ops=input_ops)

    input_ops: Dict[str, List[Op]] = {}
    input_ops["shape"] = [sub_op_pack]
    input_ops["input"] = [xgraph.get_op_by_name(node.name + "_i0")]
    xgraph.create_fixed_normal_op(node.name, "reshape", quant_config,
                                  input_ops=input_ops)
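# Illustrative sketch, not part of the original source: why the else-branch
# above inserts a transpose before the reshape. The target shape recorded in
# the graph comes from the original NCHW model, so reshaping data that is
# physically laid out as NHWC would scramble the element order; transposing
# back to NCHW first reproduces the torch result. Shapes below are made up.
def _reshape_layout_sketch():
  import numpy as np  # np is already imported by the surrounding module

  x_nchw = np.arange(2 * 4 * 3 * 3).reshape(2, 4, 3, 3)  # hypothetical NCHW tensor
  x_nhwc = x_nchw.transpose(0, 2, 3, 1)                  # XIR-side NHWC layout
  target = (2, 36)                                       # shape taken from the NCHW graph

  # reshaping the NHWC data directly gives a different element order ...
  assert not np.array_equal(x_nhwc.reshape(target), x_nchw.reshape(target))
  # ... while transposing back to NCHW first matches the original reshape
  assert np.array_equal(
      x_nhwc.transpose(0, 3, 1, 2).reshape(target), x_nchw.reshape(target))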
def binary_op(op_type: str, xgraph: XGraph, node: Node,
              quant_config: NndctQuantInfo):
  input, other = node.node_attr(node.op.AttrName.INPUT), node.node_attr(
      node.op.AttrName.OTHER)
  if isinstance(input, Tensor) and (not isinstance(other, Tensor)):
    operand1 = xgraph.get_op_by_name(input.node.name)
    dtype = other.dtype if isinstance(other, np.ndarray) else type(other)
    operand2 = np.ones(input.shape, dtype=dtype) * other
    operand2 = xgraph.create_const_op(f"{node.name}_other", operand2)
  else:
    operand1 = xgraph.get_op_by_name(input.node.name)
    operand2 = xgraph.get_op_by_name(other.node.name)

  input_ops: Dict[str, List["xir.Op"]] = {}
  input_ops["input"] = [operand1, operand2]
  input_ops["input"] = xgraph.create_input_fix_ops(input_ops["input"],
                                                   node.name, quant_config)
  xgraph.create_fixed_normal_op(node.name, op_type, quant_config,
                                input_ops=input_ops)
def __call__(self, parser, raw_node, node_scope=''):
  nndct_node = Node(
      name=get_full_name(node_scope, raw_node.name),
      dtype=convert_dtype(raw_node.dtype),
      idx=raw_node.idx)
  nndct_node.raw_kind = raw_node.kind
  nndct_node.schema = raw_node.schema
  nndct_node.is_custom_extension = raw_node.is_custom_pyop
  nndct_node.caller = raw_node.pyobj
  blob_tensor_convertor = TensorConvertor()
  for op in raw_node.outputs:
    nndct_tensor = blob_tensor_convertor(node_scope, op)
    nndct_tensor.node = nndct_node
    nndct_node.out_tensors.append(nndct_tensor)
    parser.visited_blob_tensors[op.name] = nndct_tensor

  for ip in raw_node.flatten_inputs:
    if ip.name in parser.visited_blob_tensors:
      nndct_node.in_tensors.append(parser.visited_blob_tensors[ip.name])
    elif ip.name in parser.visited_param_tensors:
      parser.node_params[nndct_node].append(
          parser.visited_param_tensors[ip.name])
      nndct_node.in_tensors.append(parser.visited_param_tensors[ip.name])

  if not raw_node.inputs:
    parser.node_input_args[nndct_node].append(nndct_node.out_tensors[0])
  else:
    parser.node_input_args[nndct_node].extend(
        [parser.get_nndct_value(i) for i in raw_node.inputs])

  return nndct_node
def _get_init_param_str(self, node: Node) -> List[str]:
  str_list = []
  if node.has_bound_params():
    return str_list
  for param_type, param_tensor in node.op.params.items():
    if param_tensor.name not in self._tensor_output_map:
      param_name = param_type.value
      param_shape = tuple(param_tensor.shape)
      param_init_str = f"self.{param_name} = torch.nn.parameter.Parameter(torch.Tensor{param_shape})"
      str_list.append(param_init_str)
      self._tensor_output_map[param_tensor.name] = f"self.{param_name}"
  return str_list
def reshape(xgraph: XGraph, node: Node, quant_config: NndctQuantInfo) -> NoReturn:
  r"""nndct reshape is a macro operator, including pack, reshape"""
  shape = node.node_attr(node.op.AttrName.SHAPE)
  sub_op_pack, pack_list = _pack(xgraph, node, "shape", shape, quant_config)
  input_ops: Dict[str, List["xir.Op"]] = {}
  input_ops["shape"] = [sub_op_pack]
  input_ops["input"] = [xgraph.get_op_by_name(node.in_nodes[0])]
  xgraph.create_fixed_normal_op(node.name, "reshape", quant_config,
                                input_ops=input_ops)
def reduction_mean(xgraph: XGraph, node: Node,
                   quant_config: NndctQuantInfo) -> NoReturn:
  attrs = _get_xir_attr_from_node(node)
  input_ops: Dict[str, List[Op]] = {}
  input_ops["input"] = [xgraph.get_op_by_name(node.in_nodes[0])]
  input_ops["input"] = xgraph.create_input_fix_ops(input_ops["input"],
                                                   node.name, quant_config)
  if len(node.node_attr(node.op.AttrName.DIMS)) == 1:
    xgraph.create_fixed_normal_op(
        node.name + "_i0", "reduction_mean", quant_config, attrs=attrs,
        input_ops=input_ops)
    scale = calculate_op_scale(
        node.in_tensors[0].shape[node.node_attr(node.op.AttrName.DIMS)[0]],
        node)
    scale = [scale]
    xgraph.create_fixed_const_op(
        name=node.name + "_i1",
        data=np.array(scale, dtype=np.float32),
        quant_info=quant_config)

    input_ops: Dict[str, List[Op]] = {}
    input_ops["input"] = [
        xgraph.get_op_by_name(node.name + "_i0"),
        xgraph.get_op_by_name(node.name + "_i1")
    ]
    xgraph.create_fixed_normal_op(node.name, "mul", quant_config,
                                  input_ops=input_ops)
  else:
    xgraph.create_fixed_normal_op(
        node.name, "reduction_mean", quant_config, attrs=attrs,
        input_ops=input_ops)
def const_xop(xgraph: XGraph, node: Node, quant_config: NndctQuantInfo) -> NoReturn:
  data = node.node_attr(node.op.AttrName.DATA)
  data_type = np.dtype(node.out_tensors[0].dtype)
  data_type = np.float32 if data_type == np.float64 else data_type
  if not isinstance(data, list):
    data = [data]
  data = np.array(data, dtype=data_type)
  data = np.transpose(
      data, node.transpose_out_order) if node.transpose_out_order else data
  xgraph.create_fixed_const_op(
      name=node.name, data=data, quant_info=quant_config)
def __call__(self, parser, raw_graph, raw_node):
  nndct_node = Node(
      name=get_full_name(raw_graph.name, raw_node.name),
      dtype=convert_dtype(raw_node.dtype),
      idx=raw_node.idx)
  blob_tensor_convertor = TensorConvertor()
  for op in raw_node.outputs:
    nndct_tensor = blob_tensor_convertor(raw_graph.name, op)
    nndct_tensor.node = nndct_node
    nndct_node.out_tensors.append(nndct_tensor)

  for ip in raw_node.flatten_inputs:
    if ip.name not in raw_graph.param_names(
    ) and parser.get_blob_tensor_by_name(ip.name):
      nndct_node.in_tensors.append(parser.get_blob_tensor_by_name(ip.name))

  return nndct_node
def _init_op_and_attrs_str(self, node: Node) -> str:
  torch_op_type = py_utils.get_torch_op_type(node.op.type)
  torch_op_attr = py_utils.get_torch_op_attr(torch_op_type)
  op_name = py_utils.get_defined_quant_module(
      torch_op_type) if not node.has_custom_op() or torch_op_attr.op_class_type in [
          TorchOpClassType.PRIMITIVE
      ] else ''
  op_name, is_defined_op = (op_name, True) if op_name else (".".join(
      [TorchSymbol.MODULE_PREFIX, "Module"]), False)
  attrs_str = ""
  if torch_op_attr.op_class_type == TorchOpClassType.NN_MODULE:
    attrs_str = self._to_map_str(
        self._get_module_attrs_map(node, torch_op_type, torch_op_attr.attrs))

  if not is_defined_op:
    attrs_str = f"'{node.op.type}',{attrs_str}" if attrs_str else f"'{node.op.type}'"
  return op_name, attrs_str
def layout_tranform(self):
  """layout_transform TORCH(NCHW) -> XIR(NHWC)"""
  custom2xir = GLOBAL_MAP.get_ele(NNDCT_KEYS.CUSTOM_TO_XIR_LIST)
  if custom2xir is None:
    custom2xir = []

  def _find_swim_order(ndim):
    return {
        2: [0, 1],
        3: [0, 2, 1],
        4: [0, 2, 3, 1],
        5: [0, 3, 4, 2, 1]
    }[ndim]

  def _find_sink_order(ndim):
    return {
        2: [0, 1],
        3: [0, 2, 1],
        4: [0, 3, 1, 2],
        5: [0, 4, 3, 1, 2]
    }[ndim]

  def _is_dim_transparent(node):
    return node.in_tensors[0].ndim and node.out_tensors[
        0].ndim and node.in_tensors[0].ndim == node.out_tensors[0].ndim

  def _is_shape_transparent(node):
    return node.in_tensors[0].shape and node.out_tensors[
        0].shape and node.in_tensors[0].shape == node.out_tensors[0].shape

  def _have_special_layout(node):
    return node.out_tensors[0].ndim and node.out_tensors[0].ndim >= 3

  def _is_custom_op(node):
    return isinstance(node.op,
                      base_op.CustomOp) and node.op.type not in custom2xir

  def _is_permute_op(node):
    return isinstance(node.op, base_op.Permute)

  def _is_terminate_op(node):
    return node.op.type == NNDCT_OP.RETURN

  implicit_ops = [
      NNDCT_OP.CONV2D, NNDCT_OP.DEPTHWISE_CONV2D,
      NNDCT_OP.DEPTHWISE_CONVTRANSPOSE2D, NNDCT_OP.CONVTRANSPOSE2D,
      NNDCT_OP.MAX_POOL, NNDCT_OP.AVG_POOL, NNDCT_OP.ADAPTIVEAVGPOOL2D,
      NNDCT_OP.INTERPOLATE, NNDCT_OP.UP_SAMPLING, NNDCT_OP.RESIZE,
      NNDCT_OP.BATCH_NORM, NNDCT_OP.MAX_POOL1D, NNDCT_OP.CONV1D,
      NNDCT_OP.CONV3D, NNDCT_OP.DEPTHWISE_CONV3D,
      NNDCT_OP.DEPTHWISE_CONVTRANSPOSE3D, NNDCT_OP.CONVTRANSPOSE3D,
      NNDCT_OP.PIXEL_SHUFFLE, NNDCT_OP.PIXEL_UNSHUFFLE, NNDCT_OP.RESIZE_3D,
      NNDCT_OP.RESIZE_NEAREST_3D, NNDCT_OP.REORG,
      NNDCT_OP.CORRELATION1D_ELEMWISE, NNDCT_OP.CORRELATION2D_ELEMWISE,
      NNDCT_OP.COST_VOLUME
  ]

  special_ops_fn = {
      NNDCT_OP.RESHAPE: shape_attr_transform_fn,
      NNDCT_OP.CONCAT: axis_attr_transform_fn,
      NNDCT_OP.STRIDED_SLICE: slice_attr_transform_fn,
      NNDCT_OP.SUM: reduce_op_attr_transform_fn,
      NNDCT_OP.MAX: reduce_op_attr_transform_fn,
      NNDCT_OP.MEAN: reduce_op_attr_transform_fn,
      NNDCT_OP.SHAPE: axis_attr_transform_fn,
      NNDCT_OP.SOFTMAX: axis_attr_transform_fn,
      NNDCT_OP.ZEROS: shape_attr_transform_fn,
  }

  # collect insert point for transpose
  insert_pos = []
  for node in self._dev_graph.nodes:
    if node.op.type in implicit_ops:
      insert_pos.append(node)

  swim_transpose = defaultdict(list)
  swim_in_transpose = defaultdict(list)
  sink_transpose = defaultdict(list)

  for node in insert_pos:
    tranpose_out_order = tuple(_find_swim_order(node.out_tensors[0].ndim))
    swim_transpose[tranpose_out_order].append(node)
    tranpose_in_order = tuple(_find_swim_order(node.in_tensors[0].ndim))
    swim_in_transpose[node] = tranpose_in_order
    tranpose_out_order = tuple(_find_sink_order(node.out_tensors[0].ndim))
    sink_transpose[tranpose_out_order].append(node)

  nodes_need_to_remove = []
  transpose_insert_between_swim = defaultdict(list)
  visited = []
  # swim_transpose_order, nodes = next(iter(swim_transpose.items()))
  for swim_transpose_order, nodes in swim_transpose.items():
    for insert_node in nodes:
      q = deque()
      q.append(insert_node)
      visited.append(insert_node)
      insert_node.transpose_out_order = swim_transpose_order
      insert_node.transpose_in_order = swim_in_transpose[insert_node]
      while len(q) > 0:
        node = q.popleft()
        for pn in self._dev_graph.parents(node):
          if pn not in visited:
            if not _have_special_layout(pn) or pn.op.type in implicit_ops:
              continue
            elif pn.op.type in [
                NNDCT_OP.INPUT, NNDCT_OP.QUANT_STUB, NNDCT_OP.CONST,
                NNDCT_OP.ZEROS
            ] or _is_dim_transparent(pn) and (not _is_permute_op(pn)) and (
                not _is_custom_op(pn)):
              pn.transpose_out_order = node.transpose_in_order
              pn.transpose_in_order = pn.transpose_out_order
              if pn.op.type in special_ops_fn:
                special_ops_fn[pn.op.type](pn, pn.transpose_out_order)
              q.append(pn)
              visited.append(pn)
            else:
              # pn.transpose_out_order = [0, 2, 3, 1]
              transpose_insert_between_swim[swim_transpose_order].append(
                  (pn, node))

  index = 0
  for transpose_order, node_pairs in transpose_insert_between_swim.items():
    for pn, cn in node_pairs:
      node_name = "_".join([pn.name, "swim_transpose", f"{index}"])
      op = base_op.Permute(NNDCT_OP.PERMUTE)
      new_node = Node(node_name, op=op, dtype=pn.dtype,
                      in_quant_part=pn.in_quant_part)
      new_node.set_node_attr(new_node.op.AttrName.ORDER, list(transpose_order))
      self._dev_graph.insert_node_between_nodes(new_node, pn, cn)
      nodes_need_to_remove.append(new_node)
      index += 1

  if transpose_insert_between_swim:
    self._dev_graph.reconnect_nodes()

  # debug
  # print("#####swim######")
  # for node in self._dev_graph.nodes:
  #   print(node.op.type, node.name, node.transpose_out_order)

  transpose_insert_between_sink = defaultdict(list)
  visited = []
  for node in self._dev_graph.nodes:
    if node.transpose_out_order:
      nodes = sink_transpose[tuple(
          _find_sink_order(len(node.transpose_out_order)))]
      if node not in nodes:
        nodes.append(node)

  for sink_transpose_order, nodes in sink_transpose.items():
    for insert_node in nodes:
      if insert_node not in visited:
        q = deque()
        q.append(insert_node)
        visited.append(insert_node)
        while len(q) > 0:
          node = q.popleft()
          for cn in self._dev_graph.children(node):
            if cn not in visited:
              if cn.op.type in implicit_ops or _is_terminate_op(cn):
                continue
              elif cn.op.type == NNDCT_OP.SHAPE:
                visited.append(cn)
                if node.transpose_out_order:
                  special_ops_fn[cn.op.type](cn, node.transpose_out_order)
                continue
              elif cn.transpose_out_order:
                q.append(cn)
                visited.append(cn)
              elif _is_dim_transparent(cn) and (not _is_permute_op(cn)) and (
                  not _is_custom_op(cn)):
                cn.transpose_in_order = node.transpose_out_order
                cn.transpose_out_order = cn.transpose_in_order
                q.append(cn)
                visited.append(cn)
                if cn.op.type in special_ops_fn:
                  special_ops_fn[cn.op.type](cn, cn.transpose_out_order)
              else:
                transpose_insert_between_sink[sink_transpose_order].append(
                    (node, cn))

  index = 0
  for transpose_order, node_pairs in transpose_insert_between_sink.items():
    for pn, cn in node_pairs:
      node_name = "_".join([pn.name, "sink_transpose", f"{index}"])
      op = base_op.Permute(NNDCT_OP.PERMUTE)
      new_node = Node(node_name, op=op, dtype=pn.dtype,
                      in_quant_part=cn.in_quant_part)
      new_node.set_node_attr(new_node.op.AttrName.ORDER, list(transpose_order))
      self._dev_graph.insert_node_between_nodes(new_node, pn, cn)
      nodes_need_to_remove.append(new_node)
      index += 1

  if transpose_insert_between_sink:
    self._dev_graph.reconnect_nodes()

  # debug
  # print("#####sink######")
  # for node in self._dev_graph.nodes:
  #   print(node.op.type, node.name, node.transpose_out_order)

  neighbor_broadcast = {}
  for node in self._dev_graph.nodes:
    if len(node.in_nodes) <= 1 or node in implicit_ops:
      continue
    if all([
        node.transpose_out_order is None
        for node in self._dev_graph.parents(node)
    ]) or all([
        node.transpose_out_order is not None
        for node in self._dev_graph.parents(node)
    ]):
      continue
    # if node.out_tensors[0].dtype != "float32":
    #   continue
    transpose_order = None
    for pn in self._dev_graph.parents(node):
      transpose_order = pn.transpose_out_order
      if transpose_order is not None:
        break
    neighbor_broadcast[node] = transpose_order

  have_neighbors = False
  for node, transpose_order in neighbor_broadcast.items():
    index = 0
    for pn in self._dev_graph.parents(node):
      if pn.transpose_out_order is None and pn.out_tensors[
          0].ndim and node.out_tensors[0].ndim and pn.out_tensors[
              0].ndim == node.out_tensors[0].ndim:
        # pn.transpose_out_order = node.transpose_out_order
        node_name = "_".join([node.name, "neighbor_transpose", f"{index}"])
        op = base_op.Permute(NNDCT_OP.PERMUTE)
        new_node = Node(node_name, op=op, dtype=node.dtype,
                        in_quant_part=pn.in_quant_part)
        new_node.set_node_attr(new_node.op.AttrName.ORDER,
                               list(transpose_order))
        self._dev_graph.insert_node_between_nodes(new_node, pn, node)
        index += 1
        nodes_need_to_remove.append(new_node)
        have_neighbors = True

  if have_neighbors:
    self._dev_graph.reconnect_nodes()

  # Debug
  # print("####neightbor######")
  # for node in self._dev_graph.nodes:
  #   print(node.op.type, node.name, node.transpose_out_order)

  # remove consecutive transpose
  def merge_father_and_child(node, visited, transpose_group, reserverd_nodes):
    visited.append(node)
    if _is_permute_op(node):
      if node.out_nodes and all(
          [_is_permute_op(cn) for cn in self._dev_graph.children(node)]):
        transpose_group.append(node)
      else:
        transpose_group.append(node)
        order = []
        reserved_trans = None
        for trans in transpose_group:
          if trans not in nodes_need_to_remove:
            reserved_trans = trans
          if not order:
            order = trans.node_attr(trans.op.AttrName.ORDER)
          else:
            new_order = len(order) * [None]
            tmp_order = trans.node_attr(trans.op.AttrName.ORDER)
            for i in range(len(order)):
              t_i = tmp_order[i]
              new_order[i] = order[t_i]
            order = new_order
        if reserved_trans is None:
          reserved_trans = transpose_group[-1]
        reserved_trans.set_node_attr(reserved_trans.op.AttrName.ORDER, order)
        reserverd_nodes.append(reserved_trans)
        transpose_group.clear()

    for cn in self._dev_graph.children(node):
      if cn not in visited:
        merge_father_and_child(cn, visited, transpose_group, reserverd_nodes)

  def merge_brothers(reserverd_nodes):
    remove_nodes = []
    for node in self._dev_graph.nodes:
      if len(node.out_nodes) > 1 and all(
          [_is_permute_op(cn) for cn in self._dev_graph.children(node)]):
        need_merge = True
        order = None
        for trans_node in self._dev_graph.children(node):
          if order is not None:
            if order != trans_node.node_attr(trans_node.op.AttrName.ORDER):
              need_merge = False
              break
          else:
            order = trans_node.node_attr(trans_node.op.AttrName.ORDER)

        if need_merge:
          reserverd_node = None
          for trans_node in self._dev_graph.children(node):
            if trans_node not in nodes_need_to_remove:
              reserverd_node = trans_node
          if reserverd_node is None:
            reserverd_node = self._dev_graph.children(node)[0]
          for trans_node in self._dev_graph.children(node):
            if trans_node is not reserverd_node and trans_node in reserverd_nodes:
              remove_nodes.append(trans_node)
              out_tensor = trans_node.out_tensors[0]
              out_tensor.replace_uses_with(reserverd_node.out_tensors[0])

    for node in remove_nodes:
      node.destroy()
    if remove_nodes:
      self._dev_graph.reconnect_nodes()

  source_nodes = []
  for node in self._dev_graph.nodes:
    if not node.in_tensors:
      source_nodes.append(node)

  transpose_group = []
  reserverd_nodes = []
  visited = []
  for source in source_nodes:
    merge_father_and_child(source, visited, transpose_group, reserverd_nodes)

  nodes_need_to_remove = [
      node for node in nodes_need_to_remove if node not in reserverd_nodes
  ]
  for node in reserverd_nodes:
    order = node.node_attr(node.op.AttrName.ORDER)
    keep_order = True
    if any([index != dim for index, dim in enumerate(order)]):
      keep_order = False
    if keep_order:
      nodes_need_to_remove.append(node)

  for node in nodes_need_to_remove:
    self._dev_graph.remove_node(node)

  merge_brothers(reserverd_nodes)

  # debug
  # print("#####finalize######")
  # for node in self._dev_graph.nodes:
  #   print(node.op.type, node.name, node.transpose_out_order)

  def delete_transpose_of_correlation(self):
    nodes_need_to_delete_for_special_ops = []
    nodes_need_to_insert_aster_special_ops = []
    nodes_need_to_merge_for_special_ops = []
    for node in self._dev_graph.nodes:
      if node.op.type == NNDCT_OP.MEAN and not node.node_attr(
          node.op.AttrName.KEEP_DIMS) and self._dev_graph.parents(node):
        pn = self._dev_graph.parents(node)[0]
        if pn.in_tensors and _is_permute_op(pn) and self._dev_graph.parents(pn):
          gpn = self._dev_graph.parents(pn)[0]
          if gpn.op.type in [
              NNDCT_OP.CORRELATION1D_ELEMWISE, NNDCT_OP.CORRELATION2D_ELEMWISE
          ] and node.out_tensors[0].ndim and gpn.out_tensors[
              0].ndim == 5 and node.out_tensors[0].ndim == 4:
            nodes_need_to_delete_for_special_ops.append(pn)
            node.transpose_in_order = tuple(_find_swim_order(5))
            node.transpose_out_order = tuple(_find_swim_order(4))
            special_ops_fn[node.op.type](node, node.transpose_in_order)
            nodes_need_to_insert_aster_special_ops.append(node)

    index = 0
    for node in nodes_need_to_insert_aster_special_ops:
      cn = self._dev_graph.children(node)[0]
      node_name = "_".join([node.name, "sink_transpose", f"{index}"])
      op = base_op.Permute(NNDCT_OP.PERMUTE)
      new_node = Node(node_name, op=op, dtype=node.dtype,
                      in_quant_part=node.in_quant_part)
      new_node.set_node_attr(new_node.op.AttrName.ORDER,
                             tuple(_find_sink_order(4)))
      self._dev_graph.insert_node_between_nodes(new_node, node, cn)
      nodes_need_to_merge_for_special_ops.append(new_node)
      index += 1

    for node in nodes_need_to_delete_for_special_ops:
      self._dev_graph.remove_node(node)

    source_nodes = []
    for node in self._dev_graph.nodes:
      if not node.in_tensors:
        source_nodes.append(node)

    transpose_group = []
    reserverd_nodes = []
    visited = []
    for source in nodes_need_to_merge_for_special_ops:
      merge_father_and_child(source, visited, transpose_group, reserverd_nodes)

    nodes_need_to_merge_for_special_ops = [
        node for node in nodes_need_to_merge_for_special_ops
        if node not in reserverd_nodes
    ]
    for node in reserverd_nodes:
      order = node.node_attr(node.op.AttrName.ORDER)
      keep_order = True
      if any([index != dim for index, dim in enumerate(order)]):
        keep_order = False
      if keep_order:
        nodes_need_to_merge_for_special_ops.append(node)

    for node in nodes_need_to_merge_for_special_ops:
      self._dev_graph.remove_node(node)

    merge_brothers(reserverd_nodes)

  delete_transpose_of_correlation(self)
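# Illustrative sketch, not part of the original source: the order-composition
# rule used by merge_father_and_child() in layout_tranform() above. Two chained
# Permute nodes with orders `father` and `child` are equivalent to one Permute
# whose order is father[child[i]]; when a "swim" transpose (NCHW -> NHWC) is
# immediately followed by a "sink" transpose (NHWC -> NCHW), the composed order
# is the identity, which is why such pairs are later removed.
def _permute_merge_sketch():
  import numpy as np  # np is already imported by the surrounding module

  father = [0, 2, 3, 1]  # 4-d swim order
  child = [0, 3, 1, 2]   # 4-d sink order
  composed = [father[i] for i in child]
  assert composed == [0, 1, 2, 3]  # identity, so the pair can be dropped

  x = np.random.rand(1, 8, 5, 5)
  assert np.array_equal(x.transpose(father).transpose(child),
                        x.transpose(composed))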
def _convert_node(self, raw_node, scope=None):
  if scope is None:
    assert self.cur_graph
    node_scope = self.cur_graph.name
  else:
    node_scope = scope

  nndct_node = Node(
      name=get_full_name(node_scope, raw_node.name),
      dtype=self.convert_dtype(raw_node.dtype),
  )
  nndct_node.source_range = raw_node.source_range
  nndct_node.scope_name = raw_node.scope_name
  if nndct_node.name in self.cur_graph:
    return self.cur_graph.node(nndct_node.name)

  # nndct_node.raw_kind = raw_node.kind
  # self.converted_node.add(raw_node)
  nndct_node.schema = raw_node.schema
  nndct_node.is_custom_extension = raw_node.is_custom_pyop
  nndct_node.caller = raw_node.pyobj
  nndct_node.owning_block = self.cur_block
  nndct_node.owning_graph = self.cur_graph

  for out in raw_node.outputs:
    full_name = get_full_name(node_scope, out.name)
    if self.cur_graph and self.cur_graph.is_tensor_in_graph(full_name):
      nndct_node.add_out_tensor(self.cur_graph.tensor(full_name))
    else:
      nndct_tensor = self._convert_tensor(out, node_scope)
      nndct_node.add_out_tensor(nndct_tensor)

  for ip in raw_node.flatten_inputs:
    full_name = get_full_name(node_scope, ip.name)
    if self.cur_graph and self.cur_graph.is_tensor_in_graph(full_name):
      nndct_node.add_in_tensor(self.cur_graph.tensor(full_name))
    elif not raw_node.outputs:
      # For Return node
      nndct_tensor = self._convert_tensor(ip, node_scope)
      nndct_node.add_in_tensor(nndct_tensor)
    if self.cur_graph and full_name in self.cur_graph.param_names():
      self.node_params[nndct_node].append(self.cur_graph.tensor(full_name))

  # from ipdb import set_trace
  # set_trace()
  node_input_args = []
  if not raw_node.inputs:
    node_input_args.extend(
        [self.get_nndct_value(i) for i in raw_node.outputs])
  else:
    node_input_args.extend(
        [self.get_nndct_value(i) for i in raw_node.inputs])

  nndct_node.op = self._create_op(raw_node.kind, nndct_node, node_input_args)
  return nndct_node
def delete_transpose_of_correlation(self):
  nodes_need_to_delete_for_special_ops = []
  nodes_need_to_insert_aster_special_ops = []
  nodes_need_to_merge_for_special_ops = []
  for node in self._dev_graph.nodes:
    if node.op.type == NNDCT_OP.MEAN and not node.node_attr(
        node.op.AttrName.KEEP_DIMS) and self._dev_graph.parents(node):
      pn = self._dev_graph.parents(node)[0]
      if pn.in_tensors and _is_permute_op(pn) and self._dev_graph.parents(pn):
        gpn = self._dev_graph.parents(pn)[0]
        if gpn.op.type in [
            NNDCT_OP.CORRELATION1D_ELEMWISE, NNDCT_OP.CORRELATION2D_ELEMWISE
        ] and node.out_tensors[0].ndim and gpn.out_tensors[
            0].ndim == 5 and node.out_tensors[0].ndim == 4:
          nodes_need_to_delete_for_special_ops.append(pn)
          node.transpose_in_order = tuple(_find_swim_order(5))
          node.transpose_out_order = tuple(_find_swim_order(4))
          special_ops_fn[node.op.type](node, node.transpose_in_order)
          nodes_need_to_insert_aster_special_ops.append(node)

  index = 0
  for node in nodes_need_to_insert_aster_special_ops:
    cn = self._dev_graph.children(node)[0]
    node_name = "_".join([node.name, "sink_transpose", f"{index}"])
    op = base_op.Permute(NNDCT_OP.PERMUTE)
    new_node = Node(node_name, op=op, dtype=node.dtype,
                    in_quant_part=node.in_quant_part)
    new_node.set_node_attr(new_node.op.AttrName.ORDER,
                           tuple(_find_sink_order(4)))
    self._dev_graph.insert_node_between_nodes(new_node, node, cn)
    nodes_need_to_merge_for_special_ops.append(new_node)
    index += 1

  for node in nodes_need_to_delete_for_special_ops:
    self._dev_graph.remove_node(node)

  source_nodes = []
  for node in self._dev_graph.nodes:
    if not node.in_tensors:
      source_nodes.append(node)

  transpose_group = []
  reserverd_nodes = []
  visited = []
  for source in nodes_need_to_merge_for_special_ops:
    merge_father_and_child(source, visited, transpose_group, reserverd_nodes)

  nodes_need_to_merge_for_special_ops = [
      node for node in nodes_need_to_merge_for_special_ops
      if node not in reserverd_nodes
  ]
  for node in reserverd_nodes:
    order = node.node_attr(node.op.AttrName.ORDER)
    keep_order = True
    if any([index != dim for index, dim in enumerate(order)]):
      keep_order = False
    if keep_order:
      nodes_need_to_merge_for_special_ops.append(node)

  for node in nodes_need_to_merge_for_special_ops:
    self._dev_graph.remove_node(node)

  merge_brothers(reserverd_nodes)
def sub(xgraph: XGraph, node: Node, quant_config: NndctQuantInfo) -> NoReturn:
  operand1, operand2 = node.node_attr(
      node.op.AttrName.INPUT), node.node_attr(node.op.AttrName.OTHER)
  _sub(xgraph, node.name, operand1, operand2, quant_config)