def forward(self, *args):
  inputs = []

  def collect_inputs(inputs, value):
    # Recursively flatten nested tuples/lists of tensors into a flat list.
    if isinstance(value, torch.Tensor):
      inputs.append(value)
    elif isinstance(value, (tuple, list)):
      for i in value:
        collect_inputs(inputs, i)

  for v in args:
    collect_inputs(inputs, v)
  # Quantize the collected input tensors and this node's parameters.
  inputs, _ = process_inputs_and_params(self.node, self.quantizer, inputs=inputs)
  caller_map = GLOBAL_MAP.get_ele(NNDCT_KEYS.NODE_CALLER_MAP)
  # Dispatch to the original callable recorded for this node at parse time.
  output = caller_map[self.node.name](*args)
  [output] = post_quant_process(self.node, [output])
  return output
def forward(self, *args):
  inputs = []

  def collect_inputs(inputs, value):
    if isinstance(value, torch.Tensor):
      inputs.append(value)
    elif isinstance(value, (tuple, list)):
      for i in value:
        collect_inputs(inputs, i)

  for v in args:
    collect_inputs(inputs, v)
  inputs = quantize_tensors(inputs, self.node, tensor_type='input')
  caller_map = GLOBAL_MAP.get_ele(NNDCT_KEYS.NODE_CALLER_MAP)
  output = caller_map[self.node.name](*args)
  output = quantize_tensors([output], self.node)[0]
  return output
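# Illustration (hedged, not part of the library): both forward variants above
# flatten arbitrarily nested tuple/list arguments into a flat tensor list
# before quantization. A minimal standalone sketch of that flattening, using
# only the public torch API:
#
#   import torch
#
#   def flatten_tensors(value, out):
#     if isinstance(value, torch.Tensor):
#       out.append(value)
#     elif isinstance(value, (tuple, list)):
#       for v in value:
#         flatten_tensors(v, out)
#     return out
#
#   args = (torch.zeros(1), [torch.ones(2), (torch.ones(3),)])
#   assert len(flatten_tensors(args, [])) == 3  # three tensors recovered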
def decorate(func):
  if op_type in NNDCT_OP.__dict__.values() and (not mapping_to_xir):
    NndctScreenLogger().error(
        f"'{op_type}' has been defined in pytorch_nndct, please use another type name.")
    exit(1)

  if not inspect.isfunction(func):
    raise RuntimeError("This API can only decorate a function object.")

  custom_op_attr_map = GLOBAL_MAP.get_ele(NNDCT_KEYS.CUSTOM_OP_ATTRS_MAP)
  if custom_op_attr_map is None:
    custom_op_attr_map = {}
    GLOBAL_MAP.set_map(NNDCT_KEYS.CUSTOM_OP_ATTRS_MAP, custom_op_attr_map)

  if op_type in custom_op_attr_map:
    NndctScreenLogger().error(f"'{op_type}' can't be registered multiple times.")
  else:
    custom_op_attr_map[op_type] = attrs_list if attrs_list is not None else []

  if mapping_to_xir is True:
    NndctScreenLogger().info(f"`{op_type}` has been mapped to xir.")
    custom2xir = GLOBAL_MAP.get_ele(NNDCT_KEYS.CUSTOM_TO_XIR_LIST)
    if custom2xir is None:
      custom2xir = []
      GLOBAL_MAP.set_map(NNDCT_KEYS.CUSTOM_TO_XIR_LIST, custom2xir)

    if op_type not in custom2xir:
      custom2xir.append(op_type)
    else:
      raise RuntimeError(
          f"{op_type} has already been mapped to XIR. Please use this op type instead of a custom op.")

  @functools.wraps(func)
  def inner(*args, **kwargs):
    # Build a one-off torch.autograd.Function subclass whose forward is the
    # decorated function, so the traced graph records a single custom op node.
    custom_op = types.new_class(op_type, (torch.autograd.Function,), {})
    custom_op.forward = staticmethod(func)
    return custom_op.apply(*args, **kwargs)

  return inner
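# Usage sketch (hypothetical call site: the enclosing decorator factory that
# binds `op_type`, `attrs_list` and `mapping_to_xir` is assumed here to be
# exposed as `register_custom_op`; the closure variables above imply that
# signature). Because the decorated function becomes the static `forward` of a
# generated torch.autograd.Function, its first parameter must be `ctx`:
#
#   @register_custom_op(op_type="my_swish", attrs_list=None, mapping_to_xir=False)
#   def my_swish(ctx, x):
#     return x * torch.sigmoid(x)
#
#   y = my_swish(torch.randn(4))  # dispatches through custom_op.apply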
def forward(self, *args):
  caller_map = GLOBAL_MAP.get_ele(NNDCT_KEYS.NODE_CALLER_MAP)
  output = caller_map[self.node.name](*args)
  [output] = post_quant_process(self.node, [output])
  return output
def do_compile(
    compile_graph: Graph,
    output_file_name=None,
    quant_config_info: Optional[NndctQuantInfo] = None,
    graph_attr_kwargs: Optional[Dict[str, Any]] = None) -> XGraph:
  r"""Convert an nndct graph to an xmodel."""
  if NndctOption.nndct_quant_off.value:
    quant_config_info = None

  xgraph = XGraph(compile_graph.name)

  if graph_attr_kwargs is not None:
    for name, attr in graph_attr_kwargs.items():
      xgraph.graph.set_attr(name, attr)

  # Create a fixed const op for every parameter tensor in the graph.
  for node in compile_graph.nodes:
    for param_type, param_tensor in node.op.params.items():
      if (node.op.type == NNDCT_OP.BATCH_NORM and
          param_type not in [node.op.ParamName.GAMMA, node.op.ParamName.BETA]):
        continue
      if xgraph.get_op_by_name(param_tensor.name):
        continue
      data = np.copy(param_tensor.data)
      if node.op.type in [
          NNDCT_OP.CONVTRANSPOSE2D, NNDCT_OP.DEPTHWISE_CONVTRANSPOSE2D
      ] and param_type == node.op.ParamName.WEIGHTS:
        # OHWI -> OH'W'I: reverse the element order along the h and w axes.
        data = np.flip(data, (1, 2))
        data = np.ascontiguousarray(data)
      elif node.op.type in [
          NNDCT_OP.CONVTRANSPOSE3D, NNDCT_OP.DEPTHWISE_CONVTRANSPOSE3D
      ] and param_type == node.op.ParamName.WEIGHTS:
        # OHWDI -> OH'W'D'I: reverse the element order along the h, w and d axes.
        data = np.flip(data, (1, 2, 3))
        data = np.ascontiguousarray(data)
      try:
        xgraph.create_fixed_const_op(
            name=param_tensor.name, data=data, quant_info=quant_config_info)
      except Exception as e:
        raise AddXopError(param_tensor.name, 'const', str(e))

  # Registered custom ops that map to XIR get the generic to_xir converter.
  custom2xir = GLOBAL_MAP.get_ele(NNDCT_KEYS.CUSTOM_TO_XIR_LIST)
  if custom2xir:
    for op_type in custom2xir:
      NNDCTIR2XIR_CONVERTOR[op_type] = to_xir(op_type)

  for node in compile_graph.nodes:
    if node.op.type == NNDCT_OP.RETURN:
      continue
    try:
      NNDCTIR2XIR_CONVERTOR.get(node.op.type, custom_xop)(xgraph, node,
                                                          quant_config_info)
    except Exception as e:
      raise AddXopError(node.name, node.op.type, str(e))

  if output_file_name:
    if quant_config_info is None:
      output_file_name += '_float'
    else:
      output_file_name += '_int'
    xgraph.export_to_xmodel(output_file_name)

  return xgraph
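# Usage sketch (hypothetical call site; `graph` is an nndct Graph produced by
# the parser and `quant_info` a NndctQuantInfo mapping):
#
#   xgraph = do_compile(graph,
#                       output_file_name="/tmp/mynet",
#                       quant_config_info=quant_info)
#
# As the tail of do_compile shows, the exported file name gets an `_int`
# suffix when quant info is supplied and a `_float` suffix otherwise.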
def forward(self, *args):
  caller_map = GLOBAL_MAP.get_ele(NNDCT_KEYS.NODE_CALLER_MAP)
  output = caller_map[self.node.name](*args)
  output = quantize_tensors([output], self.node)[0]
  return output
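# The dispatch pattern shared by all forward variants in this file: GLOBAL_MAP
# holds a NODE_CALLER_MAP from node name to the original callable captured at
# parse time, so each quantized module only wraps the stock implementation
# with (de)quantization. A minimal sketch of populating that map, using only
# the get_ele/set_map calls seen above (the node name is hypothetical):
#
#   caller_map = GLOBAL_MAP.get_ele(NNDCT_KEYS.NODE_CALLER_MAP) or {}
#   caller_map["MyNet/ReLU[relu]"] = torch.relu
#   GLOBAL_MAP.set_map(NNDCT_KEYS.NODE_CALLER_MAP, caller_map)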
def layout_tranform(self):
  """Layout transform: TORCH (NCHW) -> XIR (NHWC)."""
  custom2xir = GLOBAL_MAP.get_ele(NNDCT_KEYS.CUSTOM_TO_XIR_LIST)
  if custom2xir is None:
    custom2xir = []

  def _find_swim_order(ndim):
    # Channels-last permutation applied when a transpose swims upstream.
    return {2: [0, 1], 3: [0, 2, 1], 4: [0, 2, 3, 1], 5: [0, 3, 4, 2, 1]}[ndim]

  def _find_sink_order(ndim):
    # Inverse (channels-first) permutation applied when a transpose sinks downstream.
    return {2: [0, 1], 3: [0, 2, 1], 4: [0, 3, 1, 2], 5: [0, 4, 3, 1, 2]}[ndim]

  def _is_dim_transparent(node):
    return node.in_tensors[0].ndim and node.out_tensors[0].ndim and \
        node.in_tensors[0].ndim == node.out_tensors[0].ndim

  def _is_shape_transparent(node):
    return node.in_tensors[0].shape and node.out_tensors[0].shape and \
        node.in_tensors[0].shape == node.out_tensors[0].shape

  def _have_special_layout(node):
    return node.out_tensors[0].ndim and node.out_tensors[0].ndim >= 3

  def _is_custom_op(node):
    return isinstance(node.op, base_op.CustomOp) and node.op.type not in custom2xir

  def _is_permute_op(node):
    return isinstance(node.op, base_op.Permute)

  def _is_terminate_op(node):
    return node.op.type == NNDCT_OP.RETURN

  # Ops that implicitly require NHWC layout on the target.
  implicit_ops = [
      NNDCT_OP.CONV2D, NNDCT_OP.DEPTHWISE_CONV2D,
      NNDCT_OP.DEPTHWISE_CONVTRANSPOSE2D, NNDCT_OP.CONVTRANSPOSE2D,
      NNDCT_OP.MAX_POOL, NNDCT_OP.AVG_POOL, NNDCT_OP.ADAPTIVEAVGPOOL2D,
      NNDCT_OP.INTERPOLATE, NNDCT_OP.UP_SAMPLING, NNDCT_OP.RESIZE,
      NNDCT_OP.BATCH_NORM, NNDCT_OP.MAX_POOL1D, NNDCT_OP.CONV1D,
      NNDCT_OP.CONV3D, NNDCT_OP.DEPTHWISE_CONV3D,
      NNDCT_OP.DEPTHWISE_CONVTRANSPOSE3D, NNDCT_OP.CONVTRANSPOSE3D,
      NNDCT_OP.PIXEL_SHUFFLE, NNDCT_OP.PIXEL_UNSHUFFLE, NNDCT_OP.RESIZE_3D,
      NNDCT_OP.RESIZE_NEAREST_3D, NNDCT_OP.REORG,
      NNDCT_OP.CORRELATION1D_ELEMWISE, NNDCT_OP.CORRELATION2D_ELEMWISE,
      NNDCT_OP.COST_VOLUME
  ]

  # Ops whose attributes must be rewritten when their tensor layout changes.
  special_ops_fn = {
      NNDCT_OP.RESHAPE: shape_attr_transform_fn,
      NNDCT_OP.CONCAT: axis_attr_transform_fn,
      NNDCT_OP.STRIDED_SLICE: slice_attr_transform_fn,
      NNDCT_OP.SUM: reduce_op_attr_transform_fn,
      NNDCT_OP.MAX: reduce_op_attr_transform_fn,
      NNDCT_OP.MEAN: reduce_op_attr_transform_fn,
      NNDCT_OP.SHAPE: axis_attr_transform_fn,
      NNDCT_OP.SOFTMAX: axis_attr_transform_fn,
      NNDCT_OP.ZEROS: shape_attr_transform_fn,
  }

  # Collect insertion points for transpose nodes.
  insert_pos = []
  for node in self._dev_graph.nodes:
    if node.op.type in implicit_ops:
      insert_pos.append(node)

  swim_transpose = defaultdict(list)
  swim_in_transpose = defaultdict(list)
  sink_transpose = defaultdict(list)

  for node in insert_pos:
    transpose_out_order = tuple(_find_swim_order(node.out_tensors[0].ndim))
    swim_transpose[transpose_out_order].append(node)
    transpose_in_order = tuple(_find_swim_order(node.in_tensors[0].ndim))
    swim_in_transpose[node] = transpose_in_order
    transpose_out_order = tuple(_find_sink_order(node.out_tensors[0].ndim))
    sink_transpose[transpose_out_order].append(node)

  nodes_need_to_remove = []
  transpose_insert_between_swim = defaultdict(list)
  visited = []

  # Swim pass: propagate the layout upstream from each implicit op.
  for swim_transpose_order, nodes in swim_transpose.items():
    for insert_node in nodes:
      q = deque()
      q.append(insert_node)
      visited.append(insert_node)
      insert_node.transpose_out_order = swim_transpose_order
      insert_node.transpose_in_order = swim_in_transpose[insert_node]
      while len(q) > 0:
        node = q.popleft()
        for pn in self._dev_graph.parents(node):
          if pn not in visited:
            if not _have_special_layout(pn) or pn.op.type in implicit_ops:
              continue
            elif pn.op.type in [
                NNDCT_OP.INPUT, NNDCT_OP.QUANT_STUB, NNDCT_OP.CONST,
                NNDCT_OP.ZEROS
            ] or _is_dim_transparent(pn) and (not _is_permute_op(pn)) and (
                not _is_custom_op(pn)):
              pn.transpose_out_order = node.transpose_in_order
              pn.transpose_in_order = pn.transpose_out_order
              if pn.op.type in special_ops_fn:
                special_ops_fn[pn.op.type](pn, pn.transpose_out_order)
              q.append(pn)
              visited.append(pn)
            else:
              transpose_insert_between_swim[swim_transpose_order].append(
                  (pn, node))

  index = 0
  for transpose_order, node_pairs in transpose_insert_between_swim.items():
    for pn, cn in node_pairs:
      node_name = "_".join([pn.name, "swim_transpose", f"{index}"])
      op = base_op.Permute(NNDCT_OP.PERMUTE)
      new_node = Node(node_name, op=op, dtype=pn.dtype,
                      in_quant_part=pn.in_quant_part)
      new_node.set_node_attr(new_node.op.AttrName.ORDER, list(transpose_order))
      self._dev_graph.insert_node_between_nodes(new_node, pn, cn)
      nodes_need_to_remove.append(new_node)
      index += 1

  if transpose_insert_between_swim:
    self._dev_graph.reconnect_nodes()

  # Sink pass: propagate the layout downstream from each implicit op.
  transpose_insert_between_sink = defaultdict(list)
  visited = []
  for node in self._dev_graph.nodes:
    if node.transpose_out_order:
      nodes = sink_transpose[tuple(
          _find_sink_order(len(node.transpose_out_order)))]
      if node not in nodes:
        nodes.append(node)

  for sink_transpose_order, nodes in sink_transpose.items():
    for insert_node in nodes:
      if insert_node not in visited:
        q = deque()
        q.append(insert_node)
        visited.append(insert_node)
        while len(q) > 0:
          node = q.popleft()
          for cn in self._dev_graph.children(node):
            if cn not in visited:
              if cn.op.type in implicit_ops or _is_terminate_op(cn):
                continue
              elif cn.op.type == NNDCT_OP.SHAPE:
                visited.append(cn)
                if node.transpose_out_order:
                  special_ops_fn[cn.op.type](cn, node.transpose_out_order)
                continue
              elif cn.transpose_out_order:
                q.append(cn)
                visited.append(cn)
              elif _is_dim_transparent(cn) and (not _is_permute_op(cn)) and (
                  not _is_custom_op(cn)):
                cn.transpose_in_order = node.transpose_out_order
                cn.transpose_out_order = cn.transpose_in_order
                q.append(cn)
                visited.append(cn)
                if cn.op.type in special_ops_fn:
                  special_ops_fn[cn.op.type](cn, cn.transpose_out_order)
              else:
                transpose_insert_between_sink[sink_transpose_order].append(
                    (node, cn))

  index = 0
  for transpose_order, node_pairs in transpose_insert_between_sink.items():
    for pn, cn in node_pairs:
      node_name = "_".join([pn.name, "sink_transpose", f"{index}"])
      op = base_op.Permute(NNDCT_OP.PERMUTE)
      new_node = Node(node_name, op=op, dtype=pn.dtype,
                      in_quant_part=cn.in_quant_part)
      new_node.set_node_attr(new_node.op.AttrName.ORDER, list(transpose_order))
      self._dev_graph.insert_node_between_nodes(new_node, pn, cn)
      nodes_need_to_remove.append(new_node)
      index += 1

  if transpose_insert_between_sink:
    self._dev_graph.reconnect_nodes()

  # Neighbor broadcast: when only some parents of a multi-input node carry a
  # layout, insert transposes on the parents that do not.
  neighbor_broadcast = {}
  for node in self._dev_graph.nodes:
    if len(node.in_nodes) <= 1 or node.op.type in implicit_ops:
      continue
    if all([pn.transpose_out_order is None
            for pn in self._dev_graph.parents(node)]) or all([
                pn.transpose_out_order is not None
                for pn in self._dev_graph.parents(node)
            ]):
      continue
    transpose_order = None
    for pn in self._dev_graph.parents(node):
      transpose_order = pn.transpose_out_order
      if transpose_order is not None:
        break
    neighbor_broadcast[node] = transpose_order

  have_neighbors = False
  for node, transpose_order in neighbor_broadcast.items():
    index = 0
    for pn in self._dev_graph.parents(node):
      if pn.transpose_out_order is None and pn.out_tensors[0].ndim and \
          node.out_tensors[0].ndim and \
          pn.out_tensors[0].ndim == node.out_tensors[0].ndim:
        node_name = "_".join([node.name, "neighbor_transpose", f"{index}"])
        op = base_op.Permute(NNDCT_OP.PERMUTE)
        new_node = Node(node_name, op=op, dtype=node.dtype,
                        in_quant_part=pn.in_quant_part)
        new_node.set_node_attr(new_node.op.AttrName.ORDER,
                               list(transpose_order))
        self._dev_graph.insert_node_between_nodes(new_node, pn, node)
        index += 1
        nodes_need_to_remove.append(new_node)
        have_neighbors = True

  if have_neighbors:
    self._dev_graph.reconnect_nodes()

  # Remove consecutive transposes by composing their permutation orders.
  def merge_father_and_child(node, visited, transpose_group, reserved_nodes):
    visited.append(node)
    if _is_permute_op(node):
      if node.out_nodes and all(
          [_is_permute_op(cn) for cn in self._dev_graph.children(node)]):
        transpose_group.append(node)
      else:
        transpose_group.append(node)
        order = []
        reserved_trans = None
        for trans in transpose_group:
          if trans not in nodes_need_to_remove:
            reserved_trans = trans
          if not order:
            order = trans.node_attr(trans.op.AttrName.ORDER)
          else:
            # Compose the accumulated order with this transpose's order.
            new_order = len(order) * [None]
            tmp_order = trans.node_attr(trans.op.AttrName.ORDER)
            for i in range(len(order)):
              t_i = tmp_order[i]
              new_order[i] = order[t_i]
            order = new_order
        if reserved_trans is None:
          reserved_trans = transpose_group[-1]
        reserved_trans.set_node_attr(reserved_trans.op.AttrName.ORDER, order)
        reserved_nodes.append(reserved_trans)
        transpose_group.clear()

    for cn in self._dev_graph.children(node):
      if cn not in visited:
        merge_father_and_child(cn, visited, transpose_group, reserved_nodes)

  def merge_brothers(reserved_nodes):
    # Merge sibling transposes that share a parent and an identical order.
    remove_nodes = []
    for node in self._dev_graph.nodes:
      if len(node.out_nodes) > 1 and all(
          [_is_permute_op(cn) for cn in self._dev_graph.children(node)]):
        need_merge = True
        order = None
        for trans_node in self._dev_graph.children(node):
          if order is not None:
            if order != trans_node.node_attr(trans_node.op.AttrName.ORDER):
              need_merge = False
              break
          else:
            order = trans_node.node_attr(trans_node.op.AttrName.ORDER)

        if need_merge:
          reserved_node = None
          for trans_node in self._dev_graph.children(node):
            if trans_node not in nodes_need_to_remove:
              reserved_node = trans_node
          if reserved_node is None:
            reserved_node = self._dev_graph.children(node)[0]
          for trans_node in self._dev_graph.children(node):
            if trans_node is not reserved_node and trans_node in reserved_nodes:
              remove_nodes.append(trans_node)
              out_tensor = trans_node.out_tensors[0]
              out_tensor.replace_uses_with(reserved_node.out_tensors[0])

    for node in remove_nodes:
      node.destroy()
    if remove_nodes:
      self._dev_graph.reconnect_nodes()

  source_nodes = []
  for node in self._dev_graph.nodes:
    if not node.in_tensors:
      source_nodes.append(node)

  transpose_group = []
  reserved_nodes = []
  visited = []
  for source in source_nodes:
    merge_father_and_child(source, visited, transpose_group, reserved_nodes)

  nodes_need_to_remove = [
      node for node in nodes_need_to_remove if node not in reserved_nodes
  ]
  # A reserved transpose whose composed order is the identity can be dropped.
  for node in reserved_nodes:
    order = node.node_attr(node.op.AttrName.ORDER)
    keep_order = True
    if any([index != dim for index, dim in enumerate(order)]):
      keep_order = False
    if keep_order:
      nodes_need_to_remove.append(node)

  for node in nodes_need_to_remove:
    self._dev_graph.remove_node(node)

  merge_brothers(reserved_nodes)

  # Nested so it can reuse the helpers and tables defined above.
  def delete_transpose_of_correlation(self):
    nodes_need_to_delete_for_special_ops = []
    nodes_need_to_insert_after_special_ops = []
    nodes_need_to_merge_for_special_ops = []
    for node in self._dev_graph.nodes:
      if node.op.type == NNDCT_OP.MEAN and not node.node_attr(
          node.op.AttrName.KEEP_DIMS) and self._dev_graph.parents(node):
        pn = self._dev_graph.parents(node)[0]
        if pn.in_tensors and _is_permute_op(pn) and self._dev_graph.parents(pn):
          gpn = self._dev_graph.parents(pn)[0]
          if gpn.op.type in [
              NNDCT_OP.CORRELATION1D_ELEMWISE, NNDCT_OP.CORRELATION2D_ELEMWISE
          ] and node.out_tensors[0].ndim and gpn.out_tensors[
              0].ndim == 5 and node.out_tensors[0].ndim == 4:
            nodes_need_to_delete_for_special_ops.append(pn)
            node.transpose_in_order = tuple(_find_swim_order(5))
            node.transpose_out_order = tuple(_find_swim_order(4))
            special_ops_fn[node.op.type](node, node.transpose_in_order)
            nodes_need_to_insert_after_special_ops.append(node)

    index = 0
    for node in nodes_need_to_insert_after_special_ops:
      cn = self._dev_graph.children(node)[0]
      node_name = "_".join([node.name, "sink_transpose", f"{index}"])
      op = base_op.Permute(NNDCT_OP.PERMUTE)
      new_node = Node(node_name, op=op, dtype=node.dtype,
                      in_quant_part=node.in_quant_part)
      new_node.set_node_attr(new_node.op.AttrName.ORDER,
                             tuple(_find_sink_order(4)))
      self._dev_graph.insert_node_between_nodes(new_node, node, cn)
      nodes_need_to_merge_for_special_ops.append(new_node)
      index += 1

    for node in nodes_need_to_delete_for_special_ops:
      self._dev_graph.remove_node(node)

    transpose_group = []
    reserved_nodes = []
    visited = []
    for source in nodes_need_to_merge_for_special_ops:
      merge_father_and_child(source, visited, transpose_group, reserved_nodes)

    nodes_need_to_merge_for_special_ops = [
        node for node in nodes_need_to_merge_for_special_ops
        if node not in reserved_nodes
    ]
    for node in reserved_nodes:
      order = node.node_attr(node.op.AttrName.ORDER)
      keep_order = True
      if any([index != dim for index, dim in enumerate(order)]):
        keep_order = False
      if keep_order:
        nodes_need_to_merge_for_special_ops.append(node)

    for node in nodes_need_to_merge_for_special_ops:
      self._dev_graph.remove_node(node)

    merge_brothers(reserved_nodes)

  delete_transpose_of_correlation(self)
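# Illustration of the swim/sink orders used above: for 4-D tensors the swim
# order [0, 2, 3, 1] maps NCHW -> NHWC and the sink order [0, 3, 1, 2] maps
# back, i.e. the two permutations are inverses. A standalone check with plain
# torch:
#
#   import torch
#
#   x = torch.randn(1, 3, 8, 8)                # NCHW
#   nhwc = x.permute(0, 2, 3, 1)               # swim order -> (1, 8, 8, 3)
#   assert nhwc.permute(0, 3, 1, 2).equal(x)   # sink order restores NCHW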