def _common1_11(cls, node, **kwargs):
    axis = node.attrs.get('axis', 1)
    all_nodes = kwargs['all_nodes']
    G = kwargs['G']
    valid_name = kwargs['valid_name']
    x = all_nodes[node.input[0]]
    x_shape = x[2].shape
    if axis < 0:
        axis += len(x_shape)
    old_shape = cls._get_real_dim(x_shape)
    # versions 1 and 11 work differently to version 13. In v1 and v11 the input
    # is coerced into a 2D tensor based on the axis:
    # [a_0, a_1, ..., a_{k-1}, a_k, ..., a_{n-1}] with axis k becomes
    # [a_0 * ... * a_{k-1}, a_k * ... * a_{n-1}]
    # This 2D tensor is used for the softmax
    new_pshape = [condense(x_shape[:axis]), condense(x_shape[axis:])]
    new_shape = cls._get_real_dim(new_pshape)
    reshape_1 = ReshapeParameters(valid_name + "_reshape1",
                                  old_shape=old_shape, shape=new_shape)
    G.add_edge(NNEdge(from_node=x[0], to_node=reshape_1, from_idx=x[1], to_idx=0))
    # operation axis will either be 1 or 0
    softmax = SoftMaxParameters(valid_name, axis=len(new_shape) - 1)
    G.add_edge(NNEdge(from_node=reshape_1, to_node=softmax))
    reshape_2 = ReshapeParameters(valid_name + "_reshape2",
                                  old_shape=new_shape, shape=old_shape)
    G.add_edge(NNEdge(from_node=softmax, to_node=reshape_2))
    all_nodes[node.output[0]] = (reshape_2, 0, ProvisionalDim(x_shape))
    return softmax
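# Worked illustration of the v1/v11 2D collapse above (a standalone sketch;
# `_condense_example` is a hypothetical stand-in for the `condense` helper,
# assumed from its use here to multiply the dims of a sub-shape together,
# counting unknown (None) dims as 1).
from functools import reduce

def _condense_example(sub_shape):
    # product of the known dims, e.g. [3, 4] -> 12; None dims count as 1
    return reduce(lambda a, b: a * (b if b is not None else 1), sub_shape, 1)

# shape [2, 3, 4, 5] with axis=2 collapses to [2*3, 4*5] == [6, 20], so the
# softmax then runs over the last axis of the [6, 20] tensor
assert [_condense_example([2, 3]), _condense_example([4, 5])] == [6, 20]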
def add_constants(G, sub_g):
    """ Adds scalar constants to the subgraphs. If a constant is used in more
    than one place then it is duplicated """
    for node in sub_g.nodes():
        for edge in G.in_edges(node.name):
            if not isinstance(edge.from_node, ConstantInputParameters) or \
                    edge.from_node.out_dims[0].size() > 1:
                continue
            const_node = edge.from_node
            out_edges = G.out_edges(const_node.name)
            # if the constant is connected to more than one node then duplicate it
            if len(out_edges) > 1:
                new_const = ConstantInputParameters(
                    G.unique_name(f'{const_node}_dup'),
                    value=const_node.value.copy(),
                    dims=const_node.dims.clone())
                G.remove_edge(edge)
                G.add_edge(NNEdge(from_node=new_const, to_node=edge.to_node,
                                  to_idx=edge.to_idx))
                sub_g.add_edge(NNEdge(from_node=new_const, to_node=edge.to_node,
                                      to_idx=edge.to_idx))
            else:
                sub_g.add_edge(NNEdge(from_node=const_node, to_node=edge.to_node,
                                      to_idx=edge.to_idx))
def match(self, G: GraphView, set_identity: bool = True):
    has_modified_graph = False
    for conv_node in [params for params in G.nodes()
                      if isinstance(params, Conv2DParameters)]:
        node_list = self.get_node_list(G, conv_node)
        if node_list is None or len(node_list.order) < 2:
            continue
        if node_list.fusion_type == 'conv_active_pool':
            # with an average pool the pool cannot be included after the
            # activation; truncate the fusion to conv + activation
            if node_list.pool.pool_type == "average":
                node_list.order = node_list.order[:2]
                node_list.pool = None
        elif node_list.fusion_type == 'conv_pool_active':
            # with an average pool only a following relu can be fused
            if node_list.pool.pool_type == "average" and node_list.active.activation != "relu":
                continue
        LOG.info("fusing nodes %s", ",".join(
            (node.name for node in node_list.order)))
        has_modified_graph = True
        subgraph = GraphView()
        last_node = None
        for node in node_list.order:
            if last_node is not None:
                subgraph.add_edge(NNEdge(from_node=last_node, to_node=node))
            last_node = node
        input_mapping = [[(node_list.conv, idx)] for idx in range(3)]
        output_mapping = [(last_node, 0)]
        pnode = ConvFusionParameters(
            node_list.conv.name + '_fusion',
            fusion_type=node_list.fusion_type,
            subgraph=subgraph,
            in_dims_hint=node_list.conv.in_dims_hint,
            out_dims_hint=node_list.conv.out_dims_hint,
            input_mapping=input_mapping,
            output_mapping=output_mapping)
        if G.quantization:
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                # the fused node's record spans from the first contained
                # node's inputs to the last contained node's outputs
                prec = None
                if isinstance(qrecs[0], (SymmetricQuantizationRecord,
                                         SymmetricScalableFilterQuantizationRecord)):
                    prec = SymmetricQuantizationRecord(
                        in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                elif isinstance(qrecs[0], (MultQuantizationRecord,
                                           MultScalableFilterQuantizationRecord)):
                    prec = MultQuantizationRecord(
                        in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                elif isinstance(qrecs[0], (Float32QuantizationRecord,
                                           Float32ScalableFilterQuantizationRecord)):
                    prec = Float32QuantizationRecord(
                        in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                for node in pnode.contained_nodes():
                    G.quantization.move_to_fusion(node, pnode)
                G.quantization[NodeId(pnode)] = prec
        in_edges = G.in_edges(node_list.conv.name)
        out_edges = G.out_edges(last_node.name)
        for node in node_list.order:
            G.remove(node)
        for edge in in_edges:
            G.add_edge(NNEdge(edge.from_node, pnode,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(NNEdge(pnode, edge.to_node,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
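# Shape of the fusion mappings built above (a standalone illustration using
# string stand-ins for the real node objects): input_mapping routes each of
# the fusion's external inputs to (node, idx) pairs inside the subgraph,
# here the conv's input, weights and biases, while output_mapping names the
# internal tensor that becomes the fusion's output.
conv, relu = 'conv_node', 'relu_node'
input_mapping = [[(conv, idx)] for idx in range(3)]
assert input_mapping == [[(conv, 0)], [(conv, 1)], [(conv, 2)]]
output_mapping = [(relu, 0)]  # fusion output 0 comes from relu output 0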
def do_remove(self, args: argparse.Namespace):
    """Removes all the edges and nodes between two nodes. Will only work if the
    nodes do not affect the shape of the tensor."""
    self._check_graph()
    if any(node not in self.G for node in args.nodes):
        self.perror("node not found in graph")
        return
    node_from = self.G[args.nodes[0]]
    if len(args.nodes) == 1:
        if args.up:
            out_edges = self.G.indexed_out_edges(node_from.name)
            self.remove_to_input(node_from)
            for idx, edge_group in enumerate(out_edges):
                in_node = self.G.add_input(node_from.out_dims[idx])
                self.pfeedback(f'adding input {in_node.name}')
                for edge in edge_group:
                    self.G.add_edge(NNEdge(from_node=in_node,
                                           to_idx=edge.to_idx,
                                           to_node=edge.to_node))
        else:
            in_edges = self.G.in_edges(node_from.name)
            self.remove_to_output(node_from)
            for edge in in_edges:
                out_node = self.G.add_output()
                self.pfeedback(f'adding output {out_node.name}')
                self.G.add_edge(NNEdge(from_node=edge.from_node,
                                       from_idx=edge.from_idx,
                                       to_node=out_node))
    else:
        node_to = self.G[args.nodes[1]]
        edge_from = self.G.out_edges(node_from.name)
        edge_to = self.G.in_edges(node_to.name)
        if len(edge_from) != 1:
            self.perror("node from has more than one out edge")
            return
        edge_from = edge_from[0]
        if len(edge_to) != 1:
            edge_to = self.find_to(edge_from, set(edge_to))
            if edge_to is None:
                self.perror("nodes don't seem to be connected")
                return
        else:
            edge_to = edge_to[0]
        if edge_from == edge_to:
            self.perror("nodes are directly connected")
            return
        try:
            edges = paths_from(self.G, node_from, node_to)
        except ValueError as ex:
            self.perror(ex)
            return
        edges.remove(edge_from)
        remove_nodes(self.G, edges)
        # reconnect the surviving in edge directly to the destination
        edge_from.to_node = edge_to.to_node
        edge_from.to_idx = edge_to.to_idx
        self.G.add_edge(edge_from)
    self.G.add_dimensions()
def _import_nodes(self, G, graph, handlers, all_nodes, outputs, opts):
    for node in graph.nodes:
        handler = handlers.get(node.op_name, None)
        if not handler:
            raise ValueError("no handler found for %s" % node.op_type)
        if node.is_custom and handler:
            handler = handler.get(node.custom_op_name, None)
            if not handler:
                raise ValueError(
                    "no handler found for custom operation %s" % node.custom_op_name)
        params = handler.handle(node, all_nodes=all_nodes, G=G, opts=opts,
                                importer=self)
        if params is None:
            continue
        for idx, out_tensor in enumerate(node.output):
            output = outputs.get(out_tensor)
            if not output:
                continue
            G.add_edge(NNEdge(from_node=params, to_node=output[0],
                              from_idx=idx, to_idx=output[1]))
            if opts.get('load_quantization'):
                qtype = deepcopy(G.quantization[NodeId(params)].out_qs[idx])
                G.quantization[NodeId(output[0])] = MultQuantizationRecord(
                    in_qs=[qtype], out_qs=[qtype])
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    modified_graph = False
    # candidates have a single output index but more than one edge from it
    candidates = [node for node in G.nodes()
                  if len(G.indexed_out_edges(node.name)) == 1
                  and len(G.out_edges(node.name)) > 1]
    while candidates:
        node = candidates.pop(0)
        strings = self.explore(G, [node])
        if not strings:
            continue
        modified_graph = True
        primary = strings.pop(0)
        for pnode in primary:
            if pnode in candidates:
                candidates.remove(pnode)
        out_edges = []
        for other in strings:
            out_edges.extend(G.out_edges(other[-1].name))
            for other_node in other:
                if other_node in candidates:
                    candidates.remove(other_node)
                G.remove(other_node)
                nid = NodeId(other_node)
                if G.quantization and nid in G.quantization:
                    del G.quantization[nid]
            LOG.info(
                f'removed duplicates from {primary[0].name} '
                f'{",".join(node.name for node in other)}')
        pend = primary[-1]
        for edge in out_edges:
            G.add_edge(NNEdge(from_node=pend, to_node=edge.to_node,
                              to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return modified_graph
def _common(cls, node, **kwargs):
    all_nodes = kwargs['all_nodes']
    valid_name = kwargs['valid_name']
    G = kwargs['G']
    constant_operation = kwargs.get('constant_operation')
    # may have more than one input e.g. clip
    inputs = [all_nodes[inp] for inp in node.input]
    x = inputs[0]
    if cls.is_constant(x) and constant_operation:
        res = constant_operation(cls.get_constant(x))
        if res.size < 10:
            logger.info("reducing %s to a constant %s", valid_name, res)
        else:
            logger.info("reducing %s to a constant", valid_name)
        params = ConstantInputParameters(valid_name, value=res,
                                         constant_store=G.constant_store)
    else:
        params_args = kwargs.get('params_args', {})
        params = kwargs['params_class'](valid_name, **params_args)
        G.add_edge(NNEdge(from_node=x[0], to_node=params,
                          from_idx=x[1], to_idx=0))
    all_nodes[node.output[0]] = (params, 0, copy.deepcopy(x[2]))
    return params
def _common(cls, node, **kwargs):
    all_nodes = kwargs['all_nodes']
    G = kwargs['G']
    opts = kwargs['opts']
    qrec_class = kwargs.get('qrec_class')
    params_args = kwargs.get('params_args', {})
    constant_operation = kwargs.get('constant_operation')
    inputs = [all_nodes[inp] for inp in node.input]
    assert len(inputs) == 2
    if all(cls.is_constant(inp) for inp in inputs) and constant_operation:
        LOG.info("reducing %s to a constant", node.name)
        values = [cls.get_constant(inp) for inp in inputs]
        output_shapes = cls.implied_broadcast(inputs)
        params = ConstantInputParameters(
            node.name,
            value=constant_operation(*values),
            dims=Dim.unnamed(output_shapes[0].known_shape),
            constant_store=G.constant_store)
    else:
        params = kwargs['params_class'](node.name, **params_args)
        output_shapes = cls.implied_broadcast(inputs)
        shapes = []
        for idx, inp in enumerate(inputs):
            G.add_edge(NNEdge(from_node=inp[0], to_node=params,
                              from_idx=inp[1], to_idx=idx))
            shapes.append(inp[2].known_shape)
        if isinstance(params, Broadcastable):
            params.set_broadcast(shapes)
    if opts.get('load_quantization'):
        G.quantization[NodeId(params)] = cls.load_tf_quantization(
            node.input, node.output, qrec_class=qrec_class)
    all_nodes[node.output[0]] = (params, 0, output_shapes[0])
    return params
def match(self, G: GraphView, set_identity: bool = True):
    has_modified = False
    for node in G.nodes(node_classes=ConstantInputParameters):
        out_edges = G.out_edges(node.name)
        if len(out_edges) <= 1:
            continue
        has_modified = True
        LOG.info('node %s has more than one out edge and will be duplicated',
                 node.name)
        # keep the first edge on the original constant and rewire the rest
        # onto fresh copies
        for idx, out_edge in enumerate(out_edges[1:], start=1):
            new_constant = ConstantInputParameters(
                f'{node.name}_{idx}',
                dims=Dim.unnamed(node.dims.shape),
                value=node.value.copy())
            G.remove_edge(out_edge)
            G.add_edge(NNEdge(from_node=new_constant,
                              to_node=out_edge.to_node,
                              to_idx=out_edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return has_modified
def _import_nodes(self, G, graph, handlers, all_nodes, outputs, opts):
    used_tensors = set(all_nodes.keys()) | set(outputs.keys()) | \
        set.union(*(set(node.input) for node in graph.node))
    used_tensors.discard('')
    vars_dict = {}
    for node in graph.node:
        handler = handlers[node.domain].get(
            node.op_type, None) if node.domain in handlers else None
        if not handler:
            raise ValueError("no handler found for %s" % node.op_type)
        params = handler.handle(OnnxNode(node), all_nodes=all_nodes,
                                vars_dict=vars_dict, G=G,
                                valid_name=self._node_name(node),
                                opts=opts, used_tensors=used_tensors)
        if params is None:
            continue
        # some handlers set the meta information themselves
        if 'onnx_output' not in params.meta:
            params.meta['onnx_output'] = list(node.output)
        for out_name in node.output:
            output = outputs.get(out_name)
            if not output:
                continue
            # extract this from all_nodes since some handlers add multiple nodes
            producer = all_nodes[out_name]
            G.add_edge(NNEdge(from_node=producer[0], to_node=output[0],
                              from_idx=producer[1], to_idx=output[1]))
def _common(cls, node, **kwargs):
    params_class = kwargs['params_class']
    params_args = kwargs.get('params_args', {})
    flatten = kwargs.get('flatten')
    if params_args is None:
        params_args = {}
    all_nodes = kwargs['all_nodes']
    opts = kwargs['opts']
    G = kwargs['G']
    inputs = [all_nodes[inp] for inp in node.input]
    assert len(inputs) == 1
    inp = inputs[0]
    pout = inp[2].flatten if flatten else copy.deepcopy(inp[2])
    params = params_class(node.name, **params_args)
    if opts.get('load_quantization'):
        in_qs = kwargs.get('in_qs')
        out_qs = kwargs.get('out_qs')
        G.quantization[NodeId(params)] = cls.load_tf_quantization(
            [node.input[0]], node.output, in_qs=in_qs, out_qs=out_qs)
    G.add_edge(NNEdge(from_node=inp[0], to_node=params,
                      from_idx=inp[1], to_idx=0))
    all_nodes[node.output[0]] = (params, 0, pout)
    return params
def _common(cls, node, **kwargs):
    params_class = kwargs['params_class']
    opts_class = kwargs['opts_class']
    node_opts = node.get_options(opts_class)
    all_nodes = kwargs['all_nodes']
    opts = kwargs['opts']
    G = kwargs['G']
    inputs = [all_nodes[inp] for inp in node.input]
    x = inputs[0]
    new_shape = tuple(cls._verify_constant(inputs[1]))
    params = params_class(node.name,
                          new_shape=new_shape,
                          align_corners=node_opts.AlignCorners(),
                          halfpixel_centers=node_opts.HalfPixelCenters(),
                          in_dims_hint=[['h', 'w', 'c']],
                          out_dims_hint=[['h', 'w', 'c']])
    out_shape = params.get_output_size([Dim.unnamed(x[2].known_shape)])[0]
    if opts.get('load_quantization'):
        G.quantization[NodeId(params)] = cls.load_tf_quantization(
            [node.input[0]], [node.output[0]])
    G.add_edge(NNEdge(from_node=x[0], to_node=params, from_idx=x[1], to_idx=0))
    all_nodes[node.output[0]] = (params, 0, x[2].infer_mapping(out_shape.shape))
    return params
def _common(cls, node, **kwargs):
    all_nodes = kwargs['all_nodes']
    G = kwargs['G']
    valid_name = kwargs['valid_name']
    inputs = [all_nodes[inp] for inp in node.input]
    fold_batchnorm = kwargs['opts'].get('fold_batchnorm', True)
    x = inputs[0]
    x_shape = x[2].shape
    x_rank = len(x_shape)
    if x_rank > 4:
        raise ValueError("only 1D and 2D batch normalization is supported")
    momentum = node.attrs.get("momentum", 0.9)
    # the ONNX default for epsilon is 1e-5 (0.9 is the momentum default)
    epsilon = node.attrs.get("epsilon", 1e-5)
    bn_scale = cls.get_constant(inputs[1])
    bn_bias = cls.get_constant(inputs[2])
    running_mean = cls.get_constant(inputs[3])
    running_variance = cls.get_constant(inputs[4])
    # from version 7, force to use test mode
    if cls.SINCE_VERSION >= 7 or node.attrs.get("is_test", 0):
        spatial = None
    else:
        spatial = node.attrs.get("spatial", 1) == 1
    if fold_batchnorm and isinstance(x[0], Conv2DParameters):
        conv = x[0]
        weights = conv.weights
        if conv.has_bias:
            biases = conv.biases
        else:
            biases = np.zeros([weights.shape[0]])
            conv.has_bias = True
        # fold the batch norm into the conv weights and biases:
        # w' = diag(scale / sqrt(var + eps)) . w
        # b' = scale * b / sqrt(var + eps) + bias - scale * mean / sqrt(var + eps)
        w_conv = weights.copy().reshape(weights.shape[0], -1)
        w_bn = np.diag(bn_scale / np.sqrt(epsilon + running_variance))
        w_conv = np.matmul(w_bn, w_conv).reshape(weights.shape)
        b_bn = bn_bias - bn_scale * running_mean / np.sqrt(running_variance + epsilon)
        conv.weights = w_conv
        # any existing conv bias must pass through the batch norm scale as well
        conv.biases = bn_scale * biases / np.sqrt(running_variance + epsilon) + b_bn
        # returning None tells the importer the node has been absorbed
        all_nodes[node.output[0]] = x
    else:
        params = BatchNormalizationParameters(valid_name,
                                              scale=bn_scale,
                                              bias=bn_bias,
                                              running_mean=running_mean,
                                              running_variance=running_variance,
                                              spatial=spatial,
                                              momentum=momentum,
                                              epsilon=epsilon)
        G.add_edge(NNEdge(from_node=x[0], to_node=params,
                          from_idx=x[1], to_idx=0))
        all_nodes[node.output[0]] = (params, 0, deepcopy(x[2]))
        return params
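# A quick numerical check of the folding identity used above (a standalone
# sketch, not part of the importer): batchnorm(conv(x)) must equal conv(x)
# computed with the folded weights and biases. A 1x1 conv is modelled as a
# plain matmul for brevity.
import numpy as np

rng = np.random.default_rng(0)
out_c, in_c, eps = 4, 3, 1e-5
w, b = rng.normal(size=(out_c, in_c)), rng.normal(size=out_c)
s, beta = rng.normal(size=out_c), rng.normal(size=out_c)
mu, var = rng.normal(size=out_c), rng.uniform(0.5, 2.0, size=out_c)
x = rng.normal(size=in_c)

bn_of_conv = s * ((w @ x + b) - mu) / np.sqrt(var + eps) + beta
w_folded = np.diag(s / np.sqrt(var + eps)) @ w
b_folded = s * b / np.sqrt(var + eps) + beta - s * mu / np.sqrt(var + eps)
assert np.allclose(bn_of_conv, w_folded @ x + b_folded)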
def conv(cls, node, **kwargs):
    all_nodes = kwargs['all_nodes']
    G = kwargs['G']
    valid_name = kwargs['valid_name']
    inputs = [all_nodes[inp] for inp in node.input]
    # input N x C x H x W
    x = inputs[0]
    x_rank = len(x[2].shape)
    x_shape = x[2].shape
    spatial_size = x_rank - 2
    assert spatial_size <= 2, "only 1D and 2D convolutions supported"
    # weights M x C/group x kH x kW
    weights = cls.get_constant(inputs[1])
    out_c = weights.shape[0]
    group = node.attrs.get("group", 1)
    in_c = x_shape[1]
    filt_in_c = in_c // group
    # 1D convolutions are promoted to 2D with h == 1, mirroring the handling
    # of h and w below
    filt_h = 1 if spatial_size <= 1 else weights.shape[2]
    filt_w = 1 if spatial_size == 0 else (
        weights.shape[2] if spatial_size == 1 else weights.shape[3])
    h = 1 if spatial_size <= 1 else x_shape[2]
    w = 1 if spatial_size == 0 else (
        x_shape[2] if spatial_size == 1 else x_shape[3])
    filt_dim = Conv2DFilterDim(filt_h, filt_w, out_c, in_c=filt_in_c)
    filt_dim = filt_dim.impose_order(cls.ONNX_FILTER_ORDER)
    if len(inputs) > 2:
        biases = cls.get_constant(inputs[2])
    else:
        biases = np.zeros([out_c])
    dilations = cls.pad_start_with(
        node.attrs.get("dilations", [1] * spatial_size), [1], 2)
    strides = cls.pad_start_with(
        node.attrs.get("strides", [1] * spatial_size), [1], 2)
    pad_dim = cls.calc_pad_dim(node, spatial_size)
    params = Conv2DParameters(valid_name,
                              filt=filt_dim,
                              stride=StrideDim(strides[0], strides[1]),
                              dilation=DilationDim(dilations[0], dilations[1]),
                              groups=group,
                              padding=pad_dim,
                              has_bias=True,
                              in_dims_hint=SparseList([['c', 'h', 'w']]),
                              out_dims_hint=SparseList([['c', 'h', 'w']]),
                              constant_store=G.constant_store)
    params.weights = weights
    params.biases = biases
    in_dim = Dim.named_ordered(c=in_c, h=h, w=w)
    out_dims = params.get_output_size([in_dim])
    pout_dims = ProvisionalDim([x_shape[0]] + out_dims[0].shape)
    G.add_edge(NNEdge(from_node=x[0], to_node=params, from_idx=x[1], to_idx=0))
    all_nodes[node.output[0]] = (params, 0, pout_dims)
    return params
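# How the 1D -> 2D promotion above plays out for the stride and dilation
# lists (a minimal sketch; this reimplementation of `pad_start_with` is an
# assumption based on how it is called here: left-pad a list with the given
# filler until it reaches the requested length).
def _pad_start_with_example(sizes, pad_val, length):
    return pad_val * (length - len(sizes)) + list(sizes)

# a 1D conv with strides=[2] becomes a 2D conv with strides (1, 2): the added
# h dimension gets the neutral stride 1
assert _pad_start_with_example([2], [1], 2) == [1, 2]
# a 2D conv is left untouched
assert _pad_start_with_example([2, 3], [1], 2) == [2, 3]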
def pool(cls, node, pool_type=None, **kwargs):
    all_nodes = kwargs['all_nodes']
    G = kwargs['G']
    valid_name = kwargs['valid_name']
    inputs = [all_nodes[inp] for inp in node.input]
    x = inputs[0]
    x_shape = x[2].shape
    x_feature_shape = x_shape[2:]
    in_c = x_shape[1]
    kernel_shape = node.attrs["kernel_shape"]
    spatial_size = len(kernel_shape)
    x_rank = spatial_size + 2
    if spatial_size != 2:
        raise ValueError(
            "{}: only 2D pooling is supported (got {}D input)".format(
                valid_name, x_rank))
    h = x_shape[2]
    w = x_shape[3]
    strides = node.attrs.get("strides", [1] * spatial_size)
    stride_is_one = all(stride == 1 for stride in strides)
    dilations = node.attrs.get("dilations", [1] * spatial_size)
    if any(dilation > 1 for dilation in dilations):
        raise ValueError(valid_name + " with dilation not supported")
    # ceil_mode = bool(node.attrs.get("ceil_mode", 0))
    pad_dim = cls.calc_pad_dim(node, spatial_size)
    # a kernel that covers the whole (padded) input with stride 1 is a global
    # pool. Note: this needs to check dilation if it is ever supported.
    filter_matches_input = all(
        k_dim >= (x_dim + pad) for k_dim, x_dim, pad in zip(
            kernel_shape, x_feature_shape, [pad_dim.h, pad_dim.w]))
    if filter_matches_input and stride_is_one:
        params = GlobalPoolParameters(valid_name,
                                      pool_type=pool_type,
                                      axis=[1, 2],
                                      keep_dims=True,
                                      in_dims_hint=[['c', 'h', 'w']],
                                      out_dims_hint=[['c', 'h', 'w']])
    else:
        params = PoolingParameters(
            valid_name,
            filt=PoolFilterDim(kernel_shape[0], kernel_shape[1]),
            stride=StrideDim(strides[0], strides[1]),
            padding=pad_dim,
            pool_type=pool_type,
            in_dims_hint=[['c', 'h', 'w']],
            out_dims_hint=[['c', 'h', 'w']])
    in_dim = Dim.named_ordered(c=in_c, h=h, w=w)
    out_dims = params.get_output_size([in_dim])
    pout_dims = ProvisionalDim([x_shape[0]] + out_dims[0].shape)
    G.add_edge(NNEdge(from_node=x[0], to_node=params, from_idx=x[1], to_idx=0))
    all_nodes[node.output[0]] = (params, 0, pout_dims)
    return params
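# Illustration of the global-pool detection above (standalone sketch): a max
# pool over an 8x8 input with an 8x8 kernel, stride 1 and no padding produces
# a single window covering the whole feature map, which is exactly a global
# max pool, so the handler emits GlobalPoolParameters with axis=[1, 2].
kernel_shape, x_feature_shape, pads = [8, 8], [8, 8], [0, 0]
assert all(k >= x + p for k, x, p in zip(kernel_shape, x_feature_shape, pads))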
def _common(cls, node, **kwargs):
    node_opts = node.get_options(FullyConnectedOptions)
    G = kwargs['G']
    opts = kwargs['opts']
    all_nodes = kwargs['all_nodes']
    inputs = [all_nodes[t] for t in node.input]
    x = inputs[0]
    x_shape = x[2].shape
    x_known_shape = x[2].known_shape
    inp_sz = np.prod(np.array(x_known_shape))
    weights = inputs[1]
    weights_shape = weights[2].shape
    out_c = weights_shape[0]
    filt_dim = FcFilterDim(weights_shape[0], *x_known_shape)
    node.input[1].used = True
    check(filt_dim.sz == inp_sz, "filter doesn't match input size")
    if len(node.input) > 2:
        node.input[2].used = True
    keep_dims = node_opts.KeepNumDims()
    in_hint = [str(i) for i in range(len(x_known_shape) - 1)] + ['c']
    out_hint = in_hint.copy() if keep_dims else ['c']
    params = FcParameters(node.name,
                          filt=filt_dim,
                          has_bias=True,
                          in_dims_hint=SparseList([in_hint]),
                          out_dims_hint=SparseList([out_hint]),
                          constant_store=G.constant_store,
                          keep_dims=keep_dims)
    if opts.get('load_dequantized'):
        cls.load_dequantized_filter_parameters(params, node.input)
    else:
        cls.load_filter_parameters(G, params, node.input, node.output, opts)
    if x_shape[0] is None:
        out_shape = x_shape[:-1] + [out_c] if keep_dims else [x_shape[0], out_c]
    else:
        out_shape = x_known_shape[:-1] + [out_c] if keep_dims else [out_c]
    pout_dims = ProvisionalDim(out_shape)
    G.add_edge(NNEdge(from_node=x[0], to_node=params, from_idx=x[1], to_idx=0))
    aparams = cls.fuse_activation(node_opts, node.name, params, **kwargs)
    all_nodes[node.output[0]] = (aparams, 0, pout_dims)
    return params
def set_c_state_as_output(self, G):
    output_c_state = G.add_output()
    lstm_qrec = G.quantization and G.quantization.get(NodeId(self))
    if lstm_qrec:
        # the c state output is quantized identically to the c state input
        c_state_idx = self.INPUT_NAMES.index('c_state')
        in_q = lstm_qrec.in_qs[c_state_idx]
        lstm_qrec.out_qs.append(in_q)
        c_state_q = MultQuantizationRecord(in_qs=[in_q], out_qs=[in_q])
        G.quantization[NodeId(output_c_state)] = c_state_q
    G.add_edge(NNEdge(self, output_c_state, from_idx=1))
    G.add_dimensions()
def pool2d(cls, node, pool_type=None, **kwargs):
    all_nodes = kwargs['all_nodes']
    G = kwargs['G']
    opts = kwargs['opts']
    node_opts = node.get_options(Pool2DOptions)
    inputs = [all_nodes[inp] for inp in node.input]
    x = inputs[0]
    x = cls.remove_known_batch_dimension(G, x, node)
    x_shape = x[2].shape
    # TFLite tensors are batch x h x w x c
    in_b, h, w, in_c = tuple(x_shape)
    filt_h = node_opts.FilterHeight()
    filt_w = node_opts.FilterWidth()
    stride_h = node_opts.StrideH()
    stride_w = node_opts.StrideW()
    pad = cls.get_tf_padding(node_opts.Padding())
    # a filter that covers the whole input with stride 1 is a global pool
    filter_matches_input = h == filt_h and w == filt_w
    stride_is_one = stride_h == 1 and stride_w == 1
    if filter_matches_input and stride_is_one:
        params = GlobalPoolParameters(node.name,
                                      pool_type=pool_type,
                                      axis=[0, 1],
                                      keep_dims=True,
                                      in_dims_hint=[['h', 'w', 'c']],
                                      out_dims_hint=[['h', 'w', 'c']])
    else:
        params = PoolingParameters(node.name,
                                   filt=PoolFilterDim(filt_h, filt_w),
                                   stride=StrideDim(stride_h, stride_w),
                                   padding=pad,
                                   pool_type=pool_type,
                                   in_dims_hint=[['h', 'w', 'c']],
                                   out_dims_hint=[['h', 'w', 'c']])
    if opts.get('load_quantization'):
        G.quantization[NodeId(params)] = cls.load_tf_quantization(
            node.input, node.output)
    in_dim = Dim.named_ordered(h=h, w=w, c=in_c)
    out_dims = params.get_output_size([in_dim])
    pout_dims = ProvisionalDim([in_b] + out_dims[0].shape)
    G.add_edge(NNEdge(from_node=x[0], to_node=params, from_idx=x[1], to_idx=0))
    params = cls.fuse_activation(node_opts, node.name, params, **kwargs)
    all_nodes[node.output[0]] = (params, 0, pout_dims)
    return params
def _common(cls, node, **kwargs):
    all_nodes = kwargs['all_nodes']
    G = kwargs['G']
    opts = kwargs['opts']
    node_opts = kwargs.get("node_opts", None)
    params_args = kwargs.get('params_args', {})
    constant_operation = kwargs.get('constant_operation')
    inputs = [all_nodes[inp] for inp in node.input]
    assert len(inputs) == 2
    if all(cls.is_constant(inp) for inp in inputs) and constant_operation:
        LOG.info("reducing %s to a constant", node.name)
        values = [cls.get_constant(inp) for inp in inputs]
        output_shapes = cls.implied_broadcast(inputs)
        params = ConstantInputParameters(
            node.name,
            value=constant_operation(*values),
            dims=Dim.unnamed(output_shapes[0].known_shape),
            constant_store=G.constant_store)
    else:
        params = kwargs['params_class'](node.name, **params_args)
        output_shapes = cls.implied_broadcast(inputs)
        shapes = []
        for idx, inp in enumerate(inputs):
            G.add_edge(NNEdge(from_node=inp[0], to_node=params,
                              from_idx=inp[1], to_idx=idx))
            shapes.append(inp[2].known_shape)
        if isinstance(params, Broadcastable):
            # trim leading dimensions that are unknown or 1 down to the rank
            # of the output shape; any other leading dimension is a broadcast
            # that cannot be handled
            for idx, shape in enumerate(shapes.copy()):
                len_diff = len(shape) - len(output_shapes[0].known_shape)
                if len_diff > 0:
                    if not all(dim is None or dim == 1
                               for dim in shape[:len_diff]):
                        in_shapes = ",".join(str(shape) for shape in shapes)
                        raise ValueError(
                            f'strange broadcast {in_shapes} -> {output_shapes[0].shape}')
                    shapes[idx] = shape[len_diff:]
            params.set_broadcast(shapes)
    if opts.get('load_quantization'):
        G.quantization[NodeId(params)] = cls.load_tf_quantization(
            node.input, node.output)
    if node_opts is not None:
        params = cls.fuse_activation(node_opts, node.name, params, **kwargs)
    all_nodes[node.output[0]] = (params, 0, output_shapes[0])
    return params
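# What the leading-dimension trim above does (standalone sketch): an input of
# shape [1, 1, 3] broadcast against a rank-2 output only has redundant leading
# 1s, so it is safe to drop them; a leading dimension greater than 1 would
# change the result and is rejected as a "strange broadcast".
shape, out_rank = [1, 1, 3], 2
len_diff = len(shape) - out_rank
assert all(dim is None or dim == 1 for dim in shape[:len_diff])
assert shape[len_diff:] == [1, 3]  # the trimmed shape passed to set_broadcast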
def construct_subgraph(G, nodes):
    subg = GraphView()
    nodes = set(nodes)
    for node in nodes:
        for edge in G.out_edges(node.name):
            # only add internal edges
            if edge.to_node in nodes:
                subg.add_edge(NNEdge(from_node=edge.from_node,
                                     to_node=edge.to_node,
                                     from_idx=edge.from_idx,
                                     to_idx=edge.to_idx))

    def red_fn(state, edge):
        # group the edges entering the subgraph by their external producer
        state.setdefault((edge.from_node, edge.from_idx),
                         []).append((edge.to_node, edge.to_idx))
        return state

    inputs = reduce(red_fn,
                    set([edge for node in subg.inputs()
                         for edge in G.in_edges(node.name)]),
                    {})
    inputs_map = []
    for (fnode, fidx), outs in inputs.items():
        inp = FusionInputParameters(
            f'{fnode.name}_{fidx}_in', dims=fnode.out_dims[fidx])
        inputs_map.append((fnode, fidx))
        for (tnode, tidx) in outs:
            subg.add_edge(NNEdge(from_node=inp, to_node=tnode, to_idx=tidx))
    outputs = [(node, set(edge.from_idx for edge in G.out_edges(node.name)))
               for node in subg.outputs()]
    outputs_map = []
    for (node, fidxes) in outputs:
        output_map = []
        outputs_map.append(output_map)
        for fidx in fidxes:
            # a list, not a generator, so the edges are captured now
            output_map.append([(edge.to_node, edge.to_idx)
                               for edge in G.out_edges(node.name)
                               if edge.from_idx == fidx])
            outp = FusionOutputParameters(
                f'{node.name}_{fidx}_out', dims=node.out_dims[fidx])
            subg.add_edge(NNEdge(from_node=node, to_node=outp, from_idx=fidx))
    return (subg, inputs_map, outputs_map)
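# The reduce above just builds a one-to-many mapping from each external
# producer to its consumers inside the subgraph; a standalone equivalent on
# plain tuples (producer, producer_idx, consumer, consumer_idx):
from functools import reduce

def _red_fn_example(state, edge):
    fnode, fidx, tnode, tidx = edge
    state.setdefault((fnode, fidx), []).append((tnode, tidx))
    return state

edges = [('a', 0, 'x', 0), ('a', 0, 'y', 1), ('b', 0, 'x', 1)]
grouped = reduce(_red_fn_example, edges, {})
assert grouped == {('a', 0): [('x', 0), ('y', 1)], ('b', 0): [('x', 1)]}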
def move_constant(cls, G: GraphView, params, in_qs):
    # looks for a constant on one of the inputs. If there is one we can scale
    # by the second dimension of the second tensor. If the constant is on the
    # first input then move it to the second and transpose the operation.
    in_edges = G.indexed_in_edges(params.name)
    in1_node = in_edges[0].from_node
    in2_node = in_edges[1].from_node
    if isinstance(in2_node, ConstantInputParameters):
        return in2_node, in_qs
    elif isinstance(in1_node, ConstantInputParameters):
        if len(params.in_dims) > 2:
            # check if the bias has the correct length to move the constant.
            # It must be equal to the second dimension of the second tensor
            # after the transpose.
            bias_size = params.in_dims[2].size()
            in1_shape = params.in_dims[0].shape
            if in1_shape[1] != bias_size:
                return None, in_qs
        for edge in in_edges[:2]:
            G.remove_edge(edge)
        to_idx = 1
        # swap the edges to move the constant onto input 2
        for edge in in_edges[:2]:
            new_edge = NNEdge(from_node=edge.from_node, to_node=edge.to_node,
                              from_idx=edge.from_idx, to_idx=to_idx)
            G.add_edge(new_edge)
            to_idx = 1 - to_idx
        # use the A.B = (BT.AT)T identity
        tin1 = TransposeParameters(G.unique_name(f'{params.name}_tin1'),
                                   transpose=(1, 0))
        tin2 = TransposeParameters(G.unique_name(f'{params.name}_tin2'),
                                   transpose=(1, 0))
        tout = TransposeParameters(G.unique_name(f'{params.name}_tout'),
                                   transpose=(1, 0))
        G.insert_node_before(tin1, params)
        G.insert_node_before(tin2, params, to_idx=1)
        G.insert_node_after(params, tout)
        LOG.warning('transposes inserted on %s - rerun adjust', params.name)
        return in1_node, [in_qs[1], in_qs[0]] + in_qs[2:]
    else:
        return None, in_qs
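# A quick check (standalone, numpy only) of the identity the three transposes
# rely on: A.B == (BT.AT)T, which lets the matmul run with its operands
# swapped once the constant has been moved to the second input.
import numpy as np

rng = np.random.default_rng(0)
a, b = rng.normal(size=(2, 3)), rng.normal(size=(3, 4))
assert np.allclose(a @ b, (b.T @ a.T).T)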
def _common(cls, node, pool_type="max", **kwargs):
    all_nodes = kwargs['all_nodes']
    G = kwargs['G']
    valid_name = kwargs['valid_name']
    inputs = [all_nodes[inp] for inp in node.input]
    x = inputs[0]
    x_shape = x[2].shape
    # a global pool reduces all the spatial axes, i.e. everything after the
    # channel dimension once the unknown (batch) dims are discounted
    unknown_dims = sum(1 for dim in x_shape if dim is None)
    params = GlobalPoolParameters(
        valid_name,
        pool_type=pool_type,
        axis=tuple(range(1, len(x_shape) - unknown_dims)),
        keep_dims=True)
    pout_dims = ProvisionalDim([x_shape[0], x_shape[1]])
    G.add_edge(NNEdge(from_node=x[0], to_node=params, from_idx=x[1], to_idx=0))
    all_nodes[node.output[0]] = (params, 0, pout_dims)
    return params
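# Worked example of the axis computation above (standalone): a global pool
# input of provisional shape [None, 8, 24, 24] has one unknown (batch) dim, so
# the real tensor is [8, 24, 24] and the pool reduces axes (1, 2), the two
# spatial dimensions of the c x h x w tensor.
x_shape = [None, 8, 24, 24]
unknown_dims = sum(1 for dim in x_shape if dim is None)
assert tuple(range(1, len(x_shape) - unknown_dims)) == (1, 2)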
def set_states_as_inputs(self, G):
    input_nodes = {
        self.INPUT_NAMES[edge.to_idx]: edge.from_node
        for edge in G.in_edges(self.name)
        if isinstance(edge.from_node, ConstantInputParameters)}
    state_node_names = [name for name in self.INPUT_NAMES if "state" in name]
    for state_node_name in state_node_names:
        state_node_idx = self.INPUT_NAMES.index(state_node_name)
        state_node = input_nodes[state_node_name]
        step_idx = state_node.step_idx
        G.remove(state_node)
        # replace the constant state with a graph input of the same shape
        state_node = G.add_input(name=state_node_name + "_" + self.name,
                                 dim=Dim(list(state_node.value.shape)))
        state_node.step_idx = step_idx
        G.add_edge(NNEdge(state_node, self, to_idx=state_node_idx))
    G.add_dimensions()
def _common(cls, node, **kwargs):
    all_nodes = kwargs['all_nodes']
    valid_name = kwargs['valid_name']
    G = kwargs['G']
    inputs = [all_nodes[inp] for inp in node.input]
    x = inputs[0]
    y = inputs[1]
    shape = cls.get_constant(y)
    pshape = cls.broadcast_to(x, shape)
    if cls.is_constant(x):
        logger.info("reducing %s to a constant", valid_name)
        x_val = cls.get_constant(x)
        params = ConstantInputParameters(valid_name,
                                         value=x_val * np.ones(shape))
    else:
        params = ExpandParameters(valid_name, shape=shape)
        G.add_edge(NNEdge(x[0], params, from_idx=x[1]))
    all_nodes[node.output[0]] = (params, 0, pshape, x[3])
    return params
def _common(cls, node, **kwargs):
    all_nodes = kwargs['all_nodes']
    valid_name = kwargs['valid_name']
    G = kwargs['G']
    constant_operation = kwargs.get('constant_operation')
    constant_int_operation = kwargs.get('constant_int_operation')
    inputs = [all_nodes[inp] for inp in node.input]
    assert len(inputs) == 2
    if all(cls.is_constant(inp) for inp in inputs) and constant_operation:
        values = [cls.get_constant(inp) for inp in inputs]
        outputs = cls.implied_broadcast(inputs)
        # use the integer variant of the operation when all inputs are integral
        if constant_int_operation and all(
                np.issubdtype(val.dtype, np.integer) for val in values):
            res = constant_int_operation(*values)
        else:
            res = constant_operation(*values)
        if res.size < 10:
            logger.info("reducing %s to a constant %s", valid_name, res)
        else:
            logger.info("reducing %s to a constant", valid_name)
        params = ConstantInputParameters(valid_name, value=res,
                                         dims=Dim.unnamed(outputs[0].known_shape),
                                         constant_store=G.constant_store)
    else:
        params_args = kwargs.get('params_args', {})
        params = kwargs['params_class'](valid_name, **params_args)
        outputs = cls.implied_broadcast(inputs)
        shapes = []
        for idx, inp in enumerate(inputs):
            G.add_edge(NNEdge(from_node=inp[0], to_node=params,
                              from_idx=inp[1], to_idx=idx))
            shapes.append(inp[2].known_shape)
        if isinstance(params, Broadcastable):
            params.set_broadcast(shapes)
    all_nodes[node.output[0]] = (params, 0, outputs[0])
    return params
def remove_known_batch_dimension(cls, G, x, node, batch_axis=0):
    x_shape = x[2].shape
    if x_shape[batch_axis] is None:
        return x
    if x_shape[batch_axis] > 1:
        raise ValueError(
            f'multi batch (n={x_shape[batch_axis]}) operations are not supported by {node.name}')
    # insert a reshape that drops the known batch dimension of 1
    rparams = ReshapeParameters(
        f'{node.name}_batch',
        old_shape=Dim.unnamed(x_shape),
        shape=Dim.unnamed(x_shape[0:batch_axis] + x_shape[batch_axis + 1:]))
    if G.quantization:
        qrec = G.quantization[NodeId(x[0])]
        G.quantization[NodeId(rparams)] = QRec.copy_ktype(
            qrec, in_qs=[qrec.out_qs[0]], out_qs=[qrec.out_qs[0]])
    G.add_edge(NNEdge(from_node=x[0], to_node=rparams, from_idx=x[1], to_idx=0))
    return (rparams, 0,
            ProvisionalDim(x_shape[0:batch_axis] + [None] + x_shape[batch_axis + 1:]))
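# Shape bookkeeping for the reshape above (standalone example): a TFLite
# tensor [1, 32, 32, 3] with batch_axis=0 is reshaped to [32, 32, 3], while
# the provisional shape re-inserts the batch slot as None so downstream
# handlers still see where the batch dimension was.
x_shape, batch_axis = [1, 32, 32, 3], 0
new_shape = x_shape[0:batch_axis] + x_shape[batch_axis + 1:]
provisional = x_shape[0:batch_axis] + [None] + x_shape[batch_axis + 1:]
assert new_shape == [32, 32, 3]
assert provisional == [None, 32, 32, 3]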
def fuse_activation(cls, tfl_opts, name, params, **kwargs):
    G = kwargs['G']
    opts = kwargs['opts']
    ext = hashlib.sha1(name.encode("UTF-8")).hexdigest()[:8] \
        if opts.get('anonymise') else 'activation'
    if opts.get('load_quantization') and NodeId(params) in G.quantization:
        node_qrec = G.quantization[NodeId(params)]
    else:
        node_qrec = None
    # if node_qrec is not None and None in node_qrec.in_qs + node_qrec.out_qs:
    #     # one of the inputs is a constant or strange behaviour -> maybe
    #     # something fusions will get rid of
    #     return add_node(self.G, node)
    aparams = None
    if tfl_opts.FusedActivationFunction() == ActivationFunctionType.NONE:
        if node_qrec is not None and node_qrec.ktype.startswith('scaled'):  # and opts.get('insert_relus'):
            # no activation but an asymmetric qtype -> may be an omitted relu
            if node_qrec.out_qs[0] is not None and node_qrec.out_qs[0].min_val == 0:
                if np.all(np.round(node_qrec.out_qs[0].max_val) == 6):
                    aparams = ActivationParameters.get_activation(
                        'relu6', name + f"_{ext}")
                else:
                    aparams = ActivationParameters.get_activation(
                        'relu', name + f"_{ext}")
    else:
        aparams = ActivationParameters.get_activation(
            cls.TF_ACTIVATIONS[tfl_opts.FusedActivationFunction()],
            name + f"_{ext}")
    if aparams:
        G.add_edge(NNEdge(from_node=params, to_node=aparams))
        if opts.get('load_quantization'):
            # between the fused operation and the activation the values are
            # transferred in the int32 representation
            node_qrec = G.quantization[NodeId(params)]
            ina_qtype = deepcopy(node_qrec.out_qs[0])
            outa_qtype = deepcopy(ina_qtype)
            G.quantization[NodeId(aparams)] = QRec.scaled(
                in_qs=[ina_qtype], out_qs=[outa_qtype])
        params = aparams
    return params
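# The omitted-relu heuristic above in isolation (a standalone sketch with
# plain numbers standing in for the qtype's min_val/max_val): TFLite folds an
# activation into the op's quantized output range, so a range of [0, ~6] can
# only have been produced by a relu6 and [0, anything else] by a relu.
import numpy as np

def _guess_activation_example(min_val, max_val):
    if min_val != 0:
        return None
    return 'relu6' if np.all(np.round(max_val) == 6) else 'relu'

assert _guess_activation_example(0, 5.9997) == 'relu6'
assert _guess_activation_example(0, 12.3) == 'relu'
assert _guess_activation_example(-1.2, 4.0) is None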
def fuse_activation(cls, tfl_opts, name, params, **kwargs):
    G = kwargs['G']
    opts = kwargs['opts']
    if opts.get('load_quantization') and NodeId(params) in G.quantization:
        node_qrec = G.quantization[NodeId(params)]
    else:
        node_qrec = None
    # if node_qrec is not None and None in node_qrec.in_qs + node_qrec.out_qs:
    #     # one of the inputs is a constant or strange behaviour -> maybe
    #     # something fusions will get rid of
    #     return add_node(self.G, node)
    aparams = None
    if tfl_opts.FusedActivationFunction() == ActivationFunctionType.NONE:
        if node_qrec is not None and isinstance(node_qrec,
                                                MultQuantizationRecordBase):
            # no activation but an asymmetric qtype -> may be an omitted relu
            if node_qrec.out_qs[0].min_val == 0:
                if np.all(np.round(node_qrec.out_qs[0].max_val) == 6):
                    aparams = ActivationParameters.get_activation(
                        'relu6', name + "_activation")
                else:
                    aparams = ActivationParameters.get_activation(
                        'relu', name + "_activation")
    else:
        aparams = ActivationParameters.get_activation(
            cls.TF_ACTIVATIONS[tfl_opts.FusedActivationFunction()],
            name + "_activation")
    if aparams:
        G.add_edge(NNEdge(from_node=params, to_node=aparams))
        if opts.get('load_quantization'):
            # between the fused operation and the activation the values are
            # transferred in the int32 representation
            node_qrec = G.quantization[NodeId(params)]
            ina_qtype = deepcopy(node_qrec.out_qs[0])
            outa_qtype = deepcopy(ina_qtype)
            G.quantization[NodeId(aparams)] = MultQuantizationRecord(
                in_qs=[ina_qtype], out_qs=[outa_qtype])
        params = aparams
    return params
def _common(cls, node, **kwargs):
    all_nodes = kwargs['all_nodes']
    valid_name = kwargs['valid_name']
    G = kwargs['G']
    constant_operation = kwargs.get('constant_operation')
    inputs = [all_nodes[inp] for inp in node.input]
    assert len(inputs) == 2
    if all(cls.is_constant(inp) for inp in inputs) and constant_operation:
        logger.info("reducing %s to a constant", valid_name)
        values = [cls.get_constant(inp) for inp in inputs]
        params = ConstantInputParameters(valid_name,
                                         value=constant_operation(*values))
    else:
        params_args = kwargs.get('params_args', {})
        params = kwargs['params_class'](valid_name, **params_args)
        for idx, inp in enumerate(inputs):
            G.add_edge(NNEdge(from_node=inp[0], to_node=params,
                              from_idx=inp[1], to_idx=idx))
    outputs = cls.implied_broadcast(inputs)
    all_nodes[node.output[0]] = (params, 0, outputs[0])
    return params
def match(self, G: GraphView, set_identity: bool = True):
    has_modified_graph = False
    for pad_node in [params for params in G.nodes()
                     if isinstance(params, PadParameters)]:
        node_list = self.get_node_list(G, pad_node)
        if node_list is None or len(node_list.order) < 2:
            continue
        LOG.info("fusing nodes %s", ",".join(
            (node.name for node in node_list.order)))
        has_modified_graph = True
        subgraph = GraphView()
        padded_input_idx = G.out_edges(node_list.pad.name)[0].to_idx
        subgraph.add_edge(NNEdge(from_node=node_list.pad,
                                 to_node=node_list.add,
                                 to_idx=padded_input_idx))
        last_node = node_list.add
        node_list.add.force_quantized_index = 0
        if node_list.active:
            subgraph.add_edge(NNEdge(from_node=node_list.add,
                                     to_node=node_list.active))
            last_node = node_list.active
        if padded_input_idx == 0:
            input_mapping = [[(node_list.pad, 0)], [(node_list.add, 1)]]
        else:
            input_mapping = [[(node_list.add, 0)], [(node_list.pad, 1)]]
        output_mapping = [(last_node, 0)]
        pnode = PaddedAddFusionParameters(
            "PADDED_" + node_list.add.name,
            fusion_type=node_list.fusion_type,
            subgraph=subgraph,
            input_mapping=input_mapping,
            output_mapping=output_mapping)
        if G.quantization:
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                prec = None
                if isinstance(qrecs[0], (SymmetricQuantizationRecord,
                                         SymmetricScalableFilterQuantizationRecord)):
                    prec = SymmetricQuantizationRecord(
                        in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                elif isinstance(qrecs[0], (MultQuantizationRecord,
                                           MultScalableFilterQuantizationRecord)):
                    prec = MultQuantizationRecord(
                        in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                elif isinstance(qrecs[0], (Float32QuantizationRecord,
                                           Float32ScalableFilterQuantizationRecord)):
                    prec = Float32QuantizationRecord(
                        in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                for node in pnode.contained_nodes():
                    G.quantization.move_to_fusion(node, pnode)
                G.quantization[NodeId(pnode)] = prec
        if padded_input_idx == 0:
            in_edges = G.in_edges(node_list.pad.name) + \
                G.indexed_in_edges(node_list.add.name)[1:]
        else:
            in_edges = G.indexed_in_edges(node_list.add.name)[0:1] + \
                G.in_edges(node_list.pad.name)
        out_edges = G.out_edges(last_node.name)
        for node in node_list.order:
            G.remove(node)
        for edge in in_edges:
            G.add_edge(NNEdge(edge.from_node, pnode,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(NNEdge(pnode, edge.to_node,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return has_modified_graph