def find_direct_connects(self, G, node, has_modified_graph, find_output=True):
    """Insert copy nodes on the outputs of ``node`` reported by ``find_split_concat``.

    ``self.find_split_concat`` returns either ``None`` (nothing to do) or a
    sequence indexed by output index of ``node``. For every non-empty entry
    ("bundle") a single ``CopyParameters`` node is created and fed from that
    output of ``node``. The first edge of each edge set in the bundle is then
    re-routed so it leaves from the copy instead of from ``node``. When the
    graph is quantized, the copy gets a qrec whose input and output qtypes
    are deep copies of ``node``'s qtype for that output.

    Args:
        G: graph modified in place.
        node: node whose outputs are examined.
        has_modified_graph: modification flag accumulated by the caller.
        find_output: forwarded unchanged to ``find_split_concat``.

    Returns:
        ``has_modified_graph``, set to True if this call inserted at least
        one copy node.
    """
    # traverse reshapes or transposes that do nothing - check gen
    # find edges connected to concats
    res = self.find_split_concat(G, node, find_output=find_output)
    if res is None:
        return has_modified_graph
    if G.quantization:
        qrec = G.quantization[NodeId(node)]
    for idx, bundle in enumerate(res):
        if not bundle:
            continue
        has_modified_graph = True
        copy_node = CopyParameters("%s_copy_%s" % (node.name, idx))
        for edge_set in bundle:
            # only the first edge of each edge set is moved onto the copy
            first_edge = edge_set[0]
            G.remove_edge(first_edge)
            LOG.info('inserting copy between %s/%s and %s/%s',
                     node.name, idx, first_edge.to_node.name, first_edge.to_idx)
            G.add_edge(
                NNEdge(copy_node, first_edge.to_node, to_idx=first_edge.to_idx))
        # one feed per copy node, from output ``idx`` of ``node``
        G.add_edge(NNEdge(node, copy_node, from_idx=idx))
        if G.quantization:
            G.quantization[NodeId(copy_node)] = QRec.copy_ktype(
                qrec,
                in_qs=[deepcopy(qrec.out_qs[idx])],
                out_qs=[deepcopy(qrec.out_qs[idx])])
    # Report only real modifications: if every bundle was empty nothing was
    # inserted and the caller must not be told the graph changed.
    return has_modified_graph
def replace_function(self, G: GraphView, subgraph: GraphView):
    """Return the HSigmoid activation that replaces a matched relu/constant/matmul subgraph.

    The relu, constant and matmul nodes are located in ``subgraph`` by type;
    for each role the last node of that type wins. If ``G`` is quantized, the
    constant's qrec is deleted and the new activation receives a qrec built
    from the relu qrec, taking the relu input qtypes and the matmul output
    qtypes.

    Returns:
        ``(activation, None, None)``.
    """
    role_by_type = (
        (ReluActivationParameters, 'relu'),
        (ConstantInputParameters, 'constant'),
        (MatrixMulParameters, 'mul'),
    )
    found = dict.fromkeys(('relu', 'constant', 'mul'))
    for candidate in subgraph.nodes():
        # first matching type decides the role, mirroring an if/elif chain
        role = next((name for node_type, name in role_by_type
                     if isinstance(candidate, node_type)), None)
        if role is not None:
            found[role] = candidate

    relu_node, constant_node, mul_node = found['relu'], found['constant'], found['mul']
    activation = HSigmoidActivationParameters(
        mul_node.name + "_fused_close_hsigmoid", offset=0)
    if G.quantization:
        relu_qrec = G.quantization[NodeId(relu_node)]
        mul_qrec = G.quantization[NodeId(mul_node)]
        del G.quantization[NodeId(constant_node)]
        G.quantization[NodeId(activation)] = QRec.copy_ktype(
            relu_qrec, in_qs=relu_qrec.in_qs, out_qs=mul_qrec.out_qs)
    return activation, None, None
def insert_copy_on_common_concat_in(self, G, concat_nodes):
    """Insert a copy node on every concat input that repeats an earlier one.

    The indexed input edges of each node in ``concat_nodes`` are resolved with
    ``find_real_in_edge``. The first occurrence of a resolved edge is recorded.
    Every later input whose resolved edge was already seen, in the same concat
    or in an earlier one, is re-routed through a new ``CopyParameters`` node.
    When the graph is quantized, the copy's input and output qtypes are deep
    copies of the concat's input qtype at that position.

    Args:
        G: graph modified in place.
        concat_nodes: iterable of concat nodes to inspect.

    Returns:
        True if at least one copy node was inserted, otherwise False.
    """
    # in every concat nodes collect all the in edges (from_node, from_idx)
    # if there are repetition of tuples, insert a copy in every repetition
    # different concats cannot have the same in edge (from_node, from_idx)
    concat_in_edges = []  # list: only equality of resolved edges is relied upon
    has_modified_graph = False
    for concat_node in concat_nodes:
        for idx, in_edge in enumerate(G.indexed_in_edges(concat_node.name)):
            real_in_edge = find_real_in_edge(G, in_edge)
            if real_in_edge not in concat_in_edges:
                concat_in_edges.append(real_in_edge)
                continue
            has_modified_graph = True
            copy_node = CopyParameters("%s_copy_%s" % (concat_node.name, idx))
            G.remove_edge(in_edge)
            # source side is logged as producer/output index (from_idx), not
            # the concat's input index
            LOG.info(
                'common_concat: inserting copy between %s/%s and %s/%s',
                in_edge.from_node.name, in_edge.from_idx,
                concat_node.name, in_edge.to_idx)
            G.add_edge(
                NNEdge(in_edge.from_node, copy_node, from_idx=in_edge.from_idx))
            G.add_edge(
                NNEdge(copy_node, concat_node, to_idx=in_edge.to_idx))
            if G.quantization:
                qrec = G.quantization[NodeId(concat_node)]
                G.quantization[NodeId(copy_node)] = QRec.copy_ktype(
                    qrec,
                    in_qs=[deepcopy(qrec.in_qs[idx])],
                    out_qs=[deepcopy(qrec.in_qs[idx])])
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Collapse matched chains that start at a ``VALID_FUSIONS`` node into fusion nodes.

    For each node of a class listed in ``VALID_FUSIONS``, ``self.get_node_list``
    is asked for a chain using ``FusionMatch(self._default_ktype)``. Chains of
    two or more nodes are replaced in ``G`` by a node of
    ``node_list.fusions_class``. Every input of the first node becomes a fusion
    input, and output 0 of the last node becomes the fusion output. Qrecs are
    moved into the fusion when the graph is quantized.

    Returns:
        True if at least one chain was fused.
    """
    fused_any = False
    for start in G.nodes(node_classes=tuple(VALID_FUSIONS.keys())):
        match = self.get_node_list(G, start, FusionMatch(self._default_ktype))
        if match is None or len(match.order) < 2:
            continue
        members = list(match.order)
        LOG.info("fusing nodes %s", ",".join(member.name for member in members))
        fused_any = True

        # chain consecutive members inside a fresh subgraph
        subgraph = GraphView()
        for upstream, downstream in zip(members, members[1:]):
            subgraph.add_edge(NNEdge(from_node=upstream, to_node=downstream))
        last_node = members[-1]

        head = match.node
        # assumption here is that the first node could have multiple inputs but definitely has only
        # one output
        pnode = match.fusions_class(
            head.name + '_fusion',
            fusion_type=match.fusion_type,
            subgraph=subgraph,
            input_mapping=[[(head, idx)] for idx in range(G.num_in_edges(head.name))],
            output_mapping=[(last_node, 0)])

        if G.quantization:
            # TODO - stats
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                fused_qrec = QRec.copy_ktype(
                    qrecs[0], in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                for inner in pnode.contained_nodes():
                    G.quantization.move_to_fusion(inner, pnode)
                G.quantization[NodeId(pnode)] = fused_qrec

        # snapshot boundary edges before the chain is removed, then reattach them
        in_edges = G.in_edges(head.name)
        out_edges = G.out_edges(last_node.name)
        for member in members:
            G.remove(member)
        for edge in in_edges:
            G.add_edge(NNEdge(edge.from_node, pnode,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(NNEdge(pnode, edge.to_node,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return fused_any
def copy_qrec(self, from_node, from_dir, from_idx, to_node):
    """Give ``to_node`` a single-input/single-output qrec derived from ``from_node``.

    The qtype ``getattr(qrec, f'{from_dir}_qs')[from_idx]`` of ``from_node``'s
    qrec in ``self.qset`` is used as both the input and the output qtype of a
    new qrec. That qrec is built with ``QRec.copy_ktype`` from ``from_node``'s
    qrec and stored in ``self.qset`` under ``to_node``.

    Args:
        from_node: node whose qrec is the source.
        from_dir: attribute prefix selecting the qtype list, presumably
            ``'in'`` or ``'out'`` (``in_qs`` / ``out_qs``).
        from_idx: index into that qtype list.
        to_node: node that receives the new qrec.

    Raises:
        ValueError: if ``from_node`` has no qrec in ``self.qset``.
    """
    from_qrec = self.qset.get(NodeId(from_node))
    if from_qrec is None:
        raise ValueError(
            f'trying to copy qrec from {from_node.name} to {to_node.name} - node has no qrec'
        )
    source_qtype = getattr(from_qrec, f'{from_dir}_qs')[from_idx]
    # Independent deep copies: sharing one object between in_qs and out_qs
    # would make an edit of to_node's input qtype silently change its output.
    self.qset[NodeId(to_node)] = QRec.copy_ktype(
        from_qrec,
        in_qs=[deepcopy(source_qtype)],
        out_qs=[deepcopy(source_qtype)])
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse ``FcParameters`` nodes with the nodes ``get_node_list`` chains after them.

    The candidate activation set is ``VALID_ACTIVATIONS_POW2`` when
    ``kwargs['group_identity'] == 'pow2_match_group'`` and
    ``VALID_ACTIVATIONS_SQ8`` otherwise. Each chain of two or more nodes is
    replaced by a ``LinearFusionParameters`` node. Its three inputs map to
    inputs 0-2 of the FC node, and its output is output 0 of the last chain
    node.

    Returns:
        True if at least one chain was fused.
    """
    valid_activations = (VALID_ACTIVATIONS_POW2
                         if kwargs.get('group_identity') == 'pow2_match_group'
                         else VALID_ACTIVATIONS_SQ8)
    fused_any = False
    # snapshot the FC nodes: the graph is rewritten inside the loop
    fc_nodes = [params for params in G.nodes() if isinstance(params, FcParameters)]
    for fc_node in fc_nodes:
        node_list = self.get_node_list(G, fc_node, valid_activations)
        if node_list is None or len(node_list.order) < 2:
            continue
        members = list(node_list.order)
        LOG.info("fusing nodes %s", ",".join(member.name for member in members))
        fused_any = True

        subgraph = GraphView()
        for upstream, downstream in zip(members, members[1:]):
            subgraph.add_edge(NNEdge(from_node=upstream, to_node=downstream))
        last_node = members[-1]

        linear = node_list.linear
        pnode = LinearFusionParameters(
            linear.name + '_fusion',
            fusion_type=node_list.fusion_type,
            subgraph=subgraph,
            input_mapping=[[(linear, idx)] for idx in range(3)],
            output_mapping=[(last_node, 0)])

        if G.quantization:
            # TODO - stats
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                fused_qrec = QRec.copy_ktype(
                    qrecs[0], in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                for inner in pnode.contained_nodes():
                    G.quantization.move_to_fusion(inner, pnode)
                G.quantization[NodeId(pnode)] = fused_qrec

        in_edges = G.in_edges(linear.name)
        out_edges = G.out_edges(last_node.name)
        for member in members:
            G.remove(member)
        for edge in in_edges:
            G.add_edge(NNEdge(edge.from_node, pnode,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(NNEdge(pnode, edge.to_node,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return fused_any
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Replace ``relu(upper_bound=6) -> matmul(x, 1/6)`` with an HSigmoid activation.

    A candidate must satisfy all of the following:

    * it is a ``ReluActivationParameters`` node with ``upper_bound == 6``;
    * its single consumer is a two-input ``MatrixMulParameters`` node;
    * the matmul's other input comes from a ``ConstantInputParameters`` node
      that has exactly one consumer;
    * ``check_equals(G, constant, 1.0/6.0)`` holds.

    The relu, matmul and constant are removed and replaced by an
    ``HSigmoidActivationParameters`` node (offset 0). That node inherits the
    relu's input edges and the matmul's output edges. When quantized, it gets
    a qrec with the relu input qtypes and the matmul output qtypes.

    Returns:
        True if at least one pattern was replaced.
    """
    something_changed = False
    relu6_nodes = [node for node in G.nodes(node_classes=ReluActivationParameters)
                   if node.upper_bound == 6]
    for relu_node in relu6_nodes:
        out_edges = G.out_edges(relu_node)
        if len(out_edges) != 1 or not isinstance(out_edges[0].to_node, MatrixMulParameters):
            continue
        mul_node = out_edges[0].to_node
        in_edges = G.in_edges(mul_node)
        if len(in_edges) != 2:
            continue
        # the matmul input that does not come from the relu
        other_edge = (set(in_edges) - {out_edges[0]}).pop()
        constant_node = other_edge.from_node
        if len(G.out_edges(constant_node)) != 1:
            continue
        if (not isinstance(constant_node, ConstantInputParameters) or
                not check_equals(G, constant_node, 1.0/6.0)):
            continue
        something_changed = True
        activation = HSigmoidActivationParameters(
            G.unique_name(f'{mul_node.name}_hsigmoid'), offset=0)
        in_edges = G.in_edges(relu_node)
        out_edges = G.out_edges(mul_node)
        nodes_to_replace = [relu_node, mul_node, constant_node]
        LOG.info(f'fusing {", ".join(node.name for node in nodes_to_replace)} into HSIGMOID {activation.name}')
        G.remove_all(nodes_to_replace)
        for in_edge in in_edges:
            G.add_edge(NNEdge.clone(in_edge, to_node=activation, to_idx=0))
        for out_edge in out_edges:
            G.add_edge(NNEdge.clone(
                out_edge, from_node=activation, from_idx=0))
        if G.quantization:
            reluqrec = G.quantization[NodeId(relu_node)]
            mulqrec = G.quantization[NodeId(mul_node)]
            del G.quantization[NodeId(constant_node)]
            pqrec = QRec.copy_ktype(
                reluqrec, in_qs=reluqrec.in_qs, out_qs=mulqrec.out_qs)
            G.quantization[NodeId(activation)] = pqrec
    # Honour set_identity like the other _match implementations; it was
    # previously accepted but ignored.
    if set_identity:
        self.set_identity(G)
    return something_changed
def remove_known_batch_dimension(cls, G, x, node, batch_axis=0):
    """Remove a known batch axis of size 1 from a provisional input by inserting a reshape.

    Args:
        G: graph that receives the reshape node and its input edge.
        x: ``(producer_node, producer_output_idx, provisional_dim)``; only
            ``x[2].shape`` is read from the provisional dim.
        node: node being imported; used for the reshape name and error text.
        batch_axis: index of the batch axis in ``x[2].shape`` (negative
            values count from the end).

    Returns:
        ``x`` unchanged when the batch axis is unknown (``None``). Otherwise a
        ``(reshape, 0, ProvisionalDim(...))`` triple whose shape has the batch
        axis replaced by ``None``.

    Raises:
        ValueError: if the known batch size is greater than 1.
    """
    x_shape = x[2].shape
    if batch_axis < 0:
        batch_axis += len(x_shape)
    batch = x_shape[batch_axis]
    if batch is None:
        return x
    # Check the configured axis, not axis 0: when batch_axis != 0 the leading
    # dimension may be None (TypeError) or an unrelated size (false reject).
    if batch > 1:
        raise ValueError(
            f'multi batch (n={batch}) operations are not supported by {node.name}')
    before = list(x_shape[:batch_axis])
    after = list(x_shape[batch_axis + 1:])
    rparams = ReshapeParameters(
        f'{node.name}_batch',
        old_shape=Dim.unnamed(x_shape),
        shape=Dim.unnamed(before + after))
    if G.quantization:
        qrec = G.quantization[NodeId(x[0])]
        # the reshape consumes output x[1] of the producer, so use that qtype
        out_q = qrec.out_qs[x[1]]
        G.quantization[NodeId(rparams)] = QRec.copy_ktype(
            qrec, in_qs=[out_q], out_qs=[out_q])
    G.add_edge(
        NNEdge(from_node=x[0], to_node=rparams, from_idx=x[1], to_idx=0))
    return (rparams, 0, ProvisionalDim(before + [None] + after))
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse pooling / global pooling nodes with the nodes chained after them.

    The activation sets passed to ``get_node_list`` are the POW2 variants when
    ``kwargs['group_identity'] == 'pow2_match_group'`` and the SQ8 variants
    otherwise. Chains of two or more nodes become an ``ActivationFusion`` node
    with one input (input 0 of the pool) and one output (output 0 of the last
    node). When quantized, every contained ``GlobalPoolingParameters`` qrec has
    its output qtype reset to a deep copy of its input qtype with dtype int32.

    Returns:
        True if at least one chain was fused.
    """
    if kwargs.get('group_identity') == 'pow2_match_group':
        activations, activations_wo_pool = (VALID_ACTIVATIONS_POW2,
                                            VALID_ACTIVATIONS_POW2_WO_POOL)
    else:
        activations, activations_wo_pool = (VALID_ACTIVATIONS_SQ8,
                                            VALID_ACTIVATIONS_SQ8_WO_POOL)
    fused_any = False
    for pool_node in G.nodes(node_classes=(PoolingParameters, GlobalPoolingParameters)):
        node_list = self.get_node_list(G, pool_node, activations, activations_wo_pool)
        if node_list is None or len(node_list.order) < 2:
            continue
        members = list(node_list.order)
        LOG.info("fusing nodes %s", ",".join(member.name for member in members))
        fused_any = True

        subgraph = GraphView()
        for upstream, downstream in zip(members, members[1:]):
            subgraph.add_edge(NNEdge(from_node=upstream, to_node=downstream))
        last_node = members[-1]

        pool = node_list.pool
        pnode = ActivationFusion(
            pool.name + '_fusion',
            fusion_type=node_list.fusion_type,
            subgraph=subgraph,
            input_mapping=[[(pool, 0)]],
            output_mapping=[(last_node, 0)])

        if G.quantization:
            # TODO - stats
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                fused_qrec = QRec.copy_ktype(
                    qrecs[0], in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                for inner in pnode.contained_nodes():
                    G.quantization.move_to_fusion(inner, pnode)
                    if not isinstance(inner, GlobalPoolingParameters):
                        continue
                    # Global pooling fused with activations need to have only the activation scale
                    inner_qrec = G.quantization[NodeId(pnode, inner)]
                    inner_qrec.out_qs[0] = deepcopy(inner_qrec.in_qs[0])
                    inner_qrec.out_qs[0].dtype = np.int32
                G.quantization[NodeId(pnode)] = fused_qrec

        in_edges = G.in_edges(pool.name)
        out_edges = G.out_edges(last_node.name)
        for member in members:
            G.remove(member)
        for edge in in_edges:
            G.add_edge(NNEdge(edge.from_node, pnode,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(NNEdge(pnode, edge.to_node,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return fused_any
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse ``Conv2DParameters`` nodes with the activation/pool nodes chained after them.

    Special cases, applied before fusing:

    * ``conv_active_pool`` with an average pool: the chain is cut to its first
      two nodes and ``node_list.pool`` is cleared.
    * ``conv_pool_active`` with an average pool and a non-``"relu"``
      activation: the chain is skipped.

    Remaining chains of two or more nodes become a ``ConvFusionParameters``
    node. It has three inputs mapping to inputs 0-2 of the conv, keeps the
    conv's dims hints and input dims, and takes the last node's out_dims.

    Returns:
        True if at least one chain was fused.
    """
    valid_activations = (VALID_ACTIVATIONS_POW2
                         if kwargs.get('group_identity') == 'pow2_match_group'
                         else VALID_ACTIVATIONS_SQ8)
    fused_any = False
    conv_nodes = [params for params in G.nodes() if isinstance(params, Conv2DParameters)]
    for conv_node in conv_nodes:
        node_list = self.get_node_list(G, conv_node, valid_activations)
        if node_list is None or len(node_list.order) < 2:
            continue
        if (node_list.fusion_type == 'conv_active_pool'
                and node_list.pool.pool_type == "average"):
            # drop the trailing average pool; keep only the first two nodes
            node_list.order = node_list.order[:2:]
            node_list.pool = None
        elif (node_list.fusion_type == 'conv_pool_active'
              and node_list.pool.pool_type == "average"
              and node_list.active.activation != "relu"):
            # NOTE: This is only for old POW2 kernels - SQ8 can handle this
            # (NOTE(review): the skip is applied regardless of group_identity)
            continue

        members = list(node_list.order)
        LOG.info("fusing nodes %s", ",".join(member.name for member in members))
        fused_any = True

        subgraph = GraphView()
        for upstream, downstream in zip(members, members[1:]):
            subgraph.add_edge(NNEdge(from_node=upstream, to_node=downstream))
        last_node = members[-1]

        conv = node_list.conv
        pnode = ConvFusionParameters(
            conv.name + '_fusion',
            fusion_type=node_list.fusion_type,
            subgraph=subgraph,
            in_dims_hint=conv.in_dims_hint,
            out_dims_hint=conv.out_dims_hint,
            in_dims=deepcopy(conv.in_dims),
            out_dims=deepcopy(last_node.out_dims),
            input_mapping=[[(conv, idx)] for idx in range(3)],
            output_mapping=[(last_node, 0)])

        if G.quantization:
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                # TODO - stats
                fused_qrec = QRec.copy_ktype(qrecs[0],
                                             in_qs=deepcopy(qrecs[0].in_qs),
                                             out_qs=deepcopy(qrecs[-1].out_qs))
                for inner in pnode.contained_nodes():
                    G.quantization.move_to_fusion(inner, pnode)
                G.quantization[NodeId(pnode)] = fused_qrec

        in_edges = G.in_edges(conv.name)
        out_edges = G.out_edges(last_node.name)
        for member in members:
            G.remove(member)
        for edge in in_edges:
            G.add_edge(NNEdge(edge.from_node, pnode,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(NNEdge(pnode, edge.to_node,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return fused_any
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse a pad feeding one input of an add (plus an optional activation).

    The fused ``PaddedAddFusionParameters`` node has two inputs, ordered like
    the add's inputs. The input fed by the pad maps to the pad's single input
    (index 0); the other maps directly to the add. The fusion output is
    output 0 of the activation if present, otherwise of the add.

    Returns:
        True if at least one pad/add pair was fused.
    """
    has_modified_graph = False
    for pad_node in [params for params in G.nodes() if isinstance(params, PadParameters)]:
        node_list = self.get_node_list(G, pad_node)
        if node_list is None or len(node_list.order) < 2:
            continue
        LOG.info("fusing nodes %s", ",".join(
            (node.name for node in node_list.order)))
        has_modified_graph = True
        subgraph = GraphView()
        # which add input the pad output is connected to
        padded_input_idx = G.out_edges(node_list.pad.name)[0].to_idx
        subgraph.add_edge(
            NNEdge(from_node=node_list.pad, to_node=node_list.add,
                   to_idx=padded_input_idx))
        last_node = node_list.add
        node_list.add.force_quantized_index = 0
        if node_list.active:
            subgraph.add_edge(
                NNEdge(from_node=node_list.add, to_node=node_list.active))
            last_node = node_list.active
        # The pad has a single input, so it is always entered at index 0
        # (mapping it to pad input 1 referenced a non-existent input).
        if padded_input_idx == 0:
            input_mapping = [[(node_list.pad, 0)], [(node_list.add, 1)]]
        else:
            input_mapping = [[(node_list.add, 0)], [(node_list.pad, 0)]]
        output_mapping = [(last_node, 0)]
        pnode = PaddedAddFusionParameters(
            "PADDED_" + node_list.add.name,
            fusion_type=node_list.fusion_type,
            subgraph=subgraph,
            input_mapping=input_mapping,
            output_mapping=output_mapping)
        if G.quantization:
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            # TODO - stats
            if qrecs:
                # qrecs[1]: qrec of the second contained node (presumably the add)
                prec = QRec.copy_ktype(qrecs[1], in_qs=qrecs[1].in_qs,
                                       out_qs=qrecs[-1].out_qs)
                for node in pnode.contained_nodes():
                    G.quantization.move_to_fusion(node, pnode)
                G.quantization[NodeId(pnode)] = prec
        # in_edges is ordered by fusion input index (same order as input_mapping)
        if padded_input_idx == 0:
            in_edges = G.in_edges(node_list.pad.name) + \
                G.indexed_in_edges(node_list.add.name)[1::]
        else:
            in_edges = G.indexed_in_edges(
                node_list.add.name)[0:1:] + G.in_edges(node_list.pad.name)
        out_edges = G.out_edges(last_node.name)
        for node in node_list.order:
            G.remove(node)
        # Connect by position, not edge.to_idx: when the pad feeds add input 1,
        # both the add's input-0 edge and the pad's own input edge have
        # to_idx == 0 and would collide on fusion input 0.
        for fusion_in_idx, edge in enumerate(in_edges):
            G.add_edge(
                NNEdge(edge.from_node, pnode, from_idx=edge.from_idx,
                       to_idx=fusion_in_idx))
        for edge in out_edges:
            G.add_edge(
                NNEdge(pnode, edge.to_node, from_idx=edge.from_idx,
                       to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return has_modified_graph