def match(self, G: GraphView, set_identity: bool = True):
    """Put every softmax input on a power-of-two scale and propagate it upstream.

    Nothing is done (and ``None`` is returned) when the graph carries no
    quantization or when any softmax quantization record is not a
    ``MultQuantizationRecord``. Otherwise each softmax input qtype is
    rescaled in place with ``scale_to_pow2``, pushed up every in edge with
    ``propagate_qtype_up``, and the softmax output qtype is copied onto the
    record of each consumer node.

    Args:
        G: graph to inspect; its quantization records are updated in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.

    Returns:
        ``None`` on the early exits, ``False`` otherwise.
    """
    if not G.quantization:
        return

    softmax_nodes = [n for n in G.nodes() if isinstance(n, SoftMaxParameters)]
    softmax_qrecs = [G.quantization[NodeId(n)] for n in softmax_nodes]
    for rec in softmax_qrecs:
        if not isinstance(rec, MultQuantizationRecord):
            return

    for softmax, qrec in zip(softmax_nodes, softmax_qrecs):
        input_qtype = qrec.in_qs[0]
        input_qtype.scale_to_pow2()
        for up_edge in G.in_edges(softmax.name):
            propagate_qtype_up(G, input_qtype, up_edge)

        for down_edge in G.out_edges(softmax.name):
            assert isinstance(
                down_edge.to_node, (OutputParameters, QuantizeParameters)
            ), "Softmax is supported only at the end of the graph"
            consumer_qrec = G.quantization[NodeId(down_edge.to_node)]
            # the consumer simply passes the softmax output type through
            consumer_qrec.in_qs[0] = qrec.out_qs[0]
            consumer_qrec.out_qs[0] = qrec.out_qs[0]

    if set_identity:
        self.set_identity(G)
    return False
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Process every multiplication by the constant 1/6 that ``look_back`` recognises.

    A candidate is a ``MatrixMulParameters`` node with at least one in edge
    coming from a ``ConstantInputParameters`` for which ``check_equals``
    reports the value 1/6. Each non-``None`` record returned by ``look_back``
    is optionally extended with a ``'relu3'`` entry (when the multiply has a
    single out edge leading to a relu) and then handed to ``process_rec``.

    Args:
        G: graph to inspect and modify.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted but not used here.

    Returns:
        True if at least one record was processed, otherwise False.
    """
    sixth_muls = []
    for node in G.nodes():
        if not isinstance(node, MatrixMulParameters):
            continue
        # every in edge is evaluated; there is deliberately no short circuit
        hits = [
            isinstance(edge.from_node, ConstantInputParameters)
            and check_equals(G, edge.from_node, 1.0/6.0)
            for edge in G.in_edges(node.name)
        ]
        if any(hits):
            sixth_muls.append(node)

    oprecs = []
    for op in sixth_muls:
        found = look_back(G, op)
        if found is not None:
            oprecs.append(found)

    has_modified_graph = False
    for oprec in oprecs:
        mul_out_edges = G.out_edges(oprec['mul'][0].name)
        if len(mul_out_edges) == 1:
            follower = mul_out_edges[0].to_node
            if isinstance(follower, ReluActivationParameters):
                oprec['relu3'] = (follower, G.quantization[NodeId(follower)])
        has_modified_graph = True
        process_rec(G, oprec)

    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def match(self, G: GraphView, set_identity: bool = True):
    """Create quantization records for nodes that lack one, using their neighbours.

    Nodes whose record is ``None`` or has empty ``in_qs``/``out_qs`` are
    visited in the order given by ``G.quantization.sorted_iterator``. Input
    qtypes are taken from the predecessors' outputs and output qtypes from
    the successors' inputs (both via ``reduce_qtypes``). Only
    ``MultQuantizationRecord`` neighbours are supported; other nodes are
    skipped.

    Args:
        G: graph whose quantization table is completed in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.

    Returns:
        ``None`` when the graph has no quantization, ``False`` otherwise.

    Raises:
        NotImplementedError: when one side's qtypes are missing on a node
            that does have neighbours on that side.
    """
    if not G.quantization:
        return

    unquantized = [
        node_id for node_id, existing in G.quantization.sorted_iterator(G)
        if existing is None or not (existing.in_qs and existing.out_qs)
    ]
    for node_id in unquantized:
        if node_id.fnode_name:
            LOG.warning("can't add quantization to fused node %s", node_id.fnode_name)
            continue
        if node_id.node_name not in G:
            # previous fusions may have removed nodes from the graph
            continue

        node = node_id.get_node(G)
        preds = [NodeId(pred) for pred in G.predecessors(node.name)]
        succs = [NodeId(succ) for group in G.successors(node.name) for succ in group]

        go_back = not succs or (preds and all(p in G.quantization for p in preds))
        go_forward = not preds or (succs and all(s in G.quantization for s in succs))
        if not (go_back or go_forward):
            LOG.warning("node %s is not connected to anything and has no quantization", node.name)
            continue

        out_qtypes = None
        if go_forward:
            succ_qrecs = set(G.quantization[s] for s in succs)
            if not all(isinstance(rec, MultQuantizationRecord) for rec in succ_qrecs):
                continue
            out_qtypes = reduce_qtypes([
                (edge.from_idx, G.quantization[NodeId(edge.to_node)].in_qs[edge.to_idx])
                for edge in G.out_edges(node.name)
            ])

        in_qtypes = None
        if go_back:
            pred_qrecs = set(G.quantization[p] for p in preds)
            if not all(isinstance(rec, MultQuantizationRecord) for rec in pred_qrecs):
                continue
            in_qtypes = reduce_qtypes([
                (edge.to_idx, G.quantization[NodeId(edge.from_node)].out_qs[edge.from_idx])
                for edge in G.in_edges(node.name)
            ])

        if not in_qtypes:
            if preds:
                raise NotImplementedError("propagating qrecs not implemented")
            LOG.info("setting quantization on input node %s", node.name)
            new_qrec = MultQuantizationRecord(in_qs=deepcopy(out_qtypes),
                                              out_qs=deepcopy(out_qtypes))
        elif not out_qtypes:
            if succs:
                raise NotImplementedError("propagating qrecs not implemented")
            LOG.info("setting quantization on output node %s", node.name)
            new_qrec = MultQuantizationRecord(in_qs=deepcopy(in_qtypes),
                                              out_qs=deepcopy(in_qtypes))
        else:
            LOG.info("setting quantization on node %s", node.name)
            new_qrec = MultQuantizationRecord(in_qs=deepcopy(in_qtypes),
                                              out_qs=deepcopy(out_qtypes))
        G.quantization[node_id] = new_qrec

    if set_identity:
        self.set_identity(G)
    return False
def match(self, G: GraphView, set_identity: bool = True):
    """Impose a fixed input range on sigmoid, hard-sigmoid and hard-swish nodes.

    For each such node a new qtype is built with ``QType.from_min_max_sq``
    (keeping the dtype of the current input qtype) and propagated up the
    edge feeding input 0 with ``propagate_qtype_up``.

    Args:
        G: graph whose quantization is adjusted.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.

    Returns:
        ``None`` when the graph has no quantization, ``False`` otherwise.
    """
    if not G.quantization:
        return

    wanted = (SigmoidActivationParameters, HSigmoidActivationParameters,
              HSwishActivationParameters)
    targets = [n for n in G.nodes() if isinstance(n, wanted)]
    target_qrecs = [G.quantization[NodeId(n)] for n in targets]

    for sig_swish, qrec in zip(targets, target_qrecs):
        in_edge = [e for e in G.in_edges(sig_swish.name) if e.to_idx == 0][0]
        in_q = qrec.in_qs[0]
        min_val, max_val = in_q.min_val, in_q.max_val
        if isinstance(sig_swish,
                      (HSigmoidActivationParameters, SigmoidActivationParameters)):
            # Hard sigmoid implements a RELU, be sure 6 can be representable
            min_val, max_val = 0, 6
        elif isinstance(sig_swish, HSwishActivationParameters):
            min_val, max_val = 0, in_q.max_val * 6
        new_in_q = QType.from_min_max_sq(min_val=min_val,
                                         max_val=max_val,
                                         dtype=in_q.dtype)
        propagate_qtype_up(G, new_in_q, in_edge)

    if set_identity:
        self.set_identity(G)
    return False
def match(self, G: GraphView, set_identity: bool = True):
    """Turn the ``transpose_in``/``transpose_out`` attributes into explicit nodes.

    For each transposable node (other than transpose nodes themselves) a
    ``TransposeParameters`` node is inserted on every in edge when
    ``transpose_in`` is set and on every out edge when ``transpose_out`` is
    set. Dimension hints, when present, are moved onto the new node and the
    attribute on the original node is cleared.

    Args:
        G: graph to modify in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
    """
    # Snapshot the candidates first: inserting nodes mutates the graph.
    transposable = [
        n for n in G.nodes()
        if isinstance(n, Transposable) and not isinstance(n, TransposeParameters)
    ]

    for node in transposable:
        if node.transpose_in:
            for idx, edge in enumerate(G.in_edges(node.name)):
                tin = TransposeParameters("%s_TIN_%s" % (node.name, idx),
                                          transpose=node.transpose_in)
                if node.in_dims_hint:
                    hint_before = node.in_dims_hint[edge.to_idx]
                    hint_after = apply_reverse_transpose_to_hint(hint_before,
                                                                 node.transpose_in)
                    tin.in_dims_hint = [hint_before.copy()]
                    tin.out_dims_hint = [hint_after.copy()]
                    node.in_dims_hint[edge.to_idx] = hint_after
                G.insert_node(tin,
                              edge.from_node.name,
                              edge.to_node.name,
                              from_idx=edge.from_idx,
                              to_idx=edge.to_idx)
            node.transpose_in = None

        if node.transpose_out:
            for idx, edge in enumerate(G.out_edges(node.name)):
                tout = TransposeParameters("%s_TOUT_%s" % (node.name, idx),
                                           transpose=node.transpose_out)
                if node.out_dims_hint:
                    hint_after = node.out_dims_hint[edge.from_idx]
                    hint_before = apply_reverse_transpose_to_hint(hint_after,
                                                                  node.transpose_out)
                    tout.in_dims_hint = [hint_before.copy()]
                    tout.out_dims_hint = [hint_after.copy()]
                    node.out_dims_hint[edge.from_idx] = hint_before
                G.insert_node(tout,
                              edge.from_node.name,
                              edge.to_node.name,
                              from_idx=edge.from_idx,
                              to_idx=edge.to_idx)
            node.transpose_out = None

    if set_identity:
        self.set_identity(G)
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Bypass and delete every node for which ``self.node_does_nothing`` is true.

    The node's first in edge is reconnected directly to each consumer of the
    node, then the node is removed from the graph.

    Args:
        G: graph to modify in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted but not used here.

    Returns:
        True if at least one node was removed, otherwise False.
    """
    # Collect first so that removals do not disturb the node iteration.
    idle_nodes = [n for n in G.nodes() if self.node_does_nothing(G, n)]

    has_modified_graph = False
    for node in idle_nodes:
        has_modified_graph = True
        feed = G.in_edges(node.name)[0]
        G.remove_edge(feed)
        for out_edge in G.out_edges(node.name):
            G.remove_edge(out_edge)
            bypass = NNEdge(feed.from_node,
                            out_edge.to_node,
                            from_idx=feed.from_idx,
                            to_idx=out_edge.to_idx)
            G.add_edge(bypass)
        LOG.info(f'removing {node.name} that does nothing')
        G.remove(node)

    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Collapse chains found by ``self.get_node_list`` into single fusion nodes.

    For every node of a class listed in ``VALID_FUSIONS`` a chain of at least
    two nodes is wrapped in a new subgraph and replaced by an instance of
    ``node_list.fusions_class``. When the graph is quantized, the records of
    the contained nodes are moved into the fusion and a combined record
    (first node's inputs, last node's outputs) is stored for the fusion node.

    Args:
        G: graph to modify in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted but not used here.

    Returns:
        True if at least one fusion was created, otherwise False.
    """
    has_modified_graph = False
    for start in G.nodes(node_classes=tuple(VALID_FUSIONS.keys())):
        node_list = self.get_node_list(G, start, FusionMatch(self._default_ktype))
        if node_list is None or len(node_list.order) < 2:
            continue

        LOG.info("fusing nodes %s", ",".join(
            (member.name for member in node_list.order)))
        has_modified_graph = True

        # chain the matched nodes together inside a private subgraph
        subgraph = GraphView()
        last_node = None
        for snode in node_list.order:
            if last_node is not None:
                subgraph.add_edge(NNEdge(from_node=last_node, to_node=snode))
            last_node = snode

        # assumption here is that the first node could have multiple inputs
        # but definitely has only one output
        head = node_list.node
        input_mapping = [[(head, idx)]
                         for idx in range(G.num_in_edges(head.name))]
        output_mapping = [(last_node, 0)]
        pnode = node_list.fusions_class(head.name + '_fusion',
                                        fusion_type=node_list.fusion_type,
                                        subgraph=subgraph,
                                        input_mapping=input_mapping,
                                        output_mapping=output_mapping)

        if G.quantization:
            # TODO - stats
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                first_rec, last_rec = qrecs[0], qrecs[-1]
                prec = QRec.copy_ktype(first_rec,
                                       in_qs=first_rec.in_qs,
                                       out_qs=last_rec.out_qs)
                for fnode in pnode.contained_nodes():
                    G.quantization.move_to_fusion(fnode, pnode)
                G.quantization[NodeId(pnode)] = prec

        # remember the outside connections before the chain is removed
        in_edges = G.in_edges(head.name)
        out_edges = G.out_edges(last_node.name)
        for snode in node_list.order:
            G.remove(snode)
        for edge in in_edges:
            G.add_edge(NNEdge(edge.from_node, pnode,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(NNEdge(pnode, edge.to_node,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))

    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def match(self, G: GraphView, set_identity: bool = True):
    """Fuse convolutions with the pooling/activation nodes that follow them.

    Chains come from ``self.get_node_list``. For ``conv_active_pool`` chains
    with average pooling only the first two nodes are fused; for
    ``conv_pool_active`` chains with average pooling the fusion is skipped
    unless the activation is ``"relu"``. The chain is replaced by a
    ``ConvFusionParameters`` node, and when the graph is quantized a record
    of the same family as the first node's record is created for it.

    Args:
        G: graph to modify in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.

    Returns:
        True if at least one fusion was created, otherwise False.
    """
    has_modified_graph = False
    conv_nodes = [n for n in G.nodes() if isinstance(n, Conv2DParameters)]
    for conv_node in conv_nodes:
        node_list = self.get_node_list(G, conv_node)
        if node_list is None or len(node_list.order) < 2:
            continue

        fusion_type = node_list.fusion_type
        if fusion_type == 'conv_active_pool':
            if node_list.pool.pool_type == "average":
                # drop the pool from the chain
                node_list.order = node_list.order[:2]
                node_list.pool = None
        elif fusion_type == 'conv_pool_active':
            if (node_list.pool.pool_type == "average"
                    and node_list.active.activation != "relu"):
                continue

        LOG.info("fusing nodes %s",
                 ",".join((member.name for member in node_list.order)))
        has_modified_graph = True

        subgraph = GraphView()
        last_node = None
        for member in node_list.order:
            if last_node is not None:
                subgraph.add_edge(NNEdge(from_node=last_node, to_node=member))
            last_node = member

        input_mapping = [[(node_list.conv, idx)] for idx in range(3)]
        output_mapping = [(last_node, 0)]
        pnode = ConvFusionParameters(node_list.conv.name + '_fusion',
                                     fusion_type=node_list.fusion_type,
                                     subgraph=subgraph,
                                     in_dims_hint=node_list.conv.in_dims_hint,
                                     out_dims_hint=node_list.conv.out_dims_hint,
                                     input_mapping=input_mapping,
                                     output_mapping=output_mapping)

        if G.quantization:
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                # ordered (accepted record types, record type to create) pairs
                families = (
                    ((SymmetricQuantizationRecord,
                      SymmetricScalableFilterQuantizationRecord),
                     SymmetricQuantizationRecord),
                    ((MultQuantizationRecord,
                      MultScalableFilterQuantizationRecord),
                     MultQuantizationRecord),
                    ((Float32QuantizationRecord,
                      Float32ScalableFilterQuantizationRecord),
                     Float32QuantizationRecord),
                )
                prec = None
                for accepted, record_cls in families:
                    if isinstance(qrecs[0], accepted):
                        prec = record_cls(in_qs=qrecs[0].in_qs,
                                          out_qs=qrecs[-1].out_qs)
                        break
                for member in pnode.contained_nodes():
                    G.quantization.move_to_fusion(member, pnode)
                G.quantization[NodeId(pnode)] = prec

        in_edges = G.in_edges(node_list.conv.name)
        out_edges = G.out_edges(last_node.name)
        for member in node_list.order:
            G.remove(member)
        for edge in in_edges:
            G.add_edge(NNEdge(edge.from_node, pnode,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(NNEdge(pnode, edge.to_node,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))

    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Eliminate the unpack/slice node that follows multi-cell RNN outputs.

    Candidates are produced by ``self.find_unpack`` for every RNN node with
    more than one output cell and filtered by ``self.validate_slices`` and
    ``self.validate_multi_branch``. Each path is processed with
    ``self.process_path``, then the unpack node is removed and its producer
    is wired to its consumers. When a strided slice changes the shape, a
    ``ReshapeParameters`` node is inserted in its place.

    Args:
        G: graph to modify in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted but not used here.

    Returns:
        False when there is nothing to process, True otherwise.
    """
    rnn_nodes = [
        self.find_unpack(G, node) for node in G.nodes()
        if isinstance(node, RNNBaseParameters) and node.n_output_cells > 1
    ]
    rnn_nodes_by_slice = self.validate_slices(G, rnn_nodes)
    rnn_nodes_by_slice = self.validate_multi_branch(G, rnn_nodes_by_slice)
    if not rnn_nodes_by_slice:
        return False

    for unpack_node, rnn_unpacks in rnn_nodes_by_slice.items():
        modified_nodes = set()
        for rnn_unpack in rnn_unpacks:
            self.process_path(G, rnn_unpack, modified_nodes)

        # since process path will have removed all unnecessary nodes the
        # edges will be correct here
        out_edges = G.out_edges(unpack_node.name)
        in_edges = G.in_edges(unpack_node.name)
        assert len(in_edges) == 1, \
            "expecting unpack node to have only one in edge"
        in_edge = in_edges[0]
        changes_shape = (unpack_node.changes_shape
                         if isinstance(unpack_node, StridedSliceParameters)
                         else False)

        LOG.info("Eliminating last cell unpack: %s", unpack_node.name)
        G.remove(unpack_node)

        # Here the strided slice can change the output shape of the RNN
        # so insert a reshape to do the shape change
        if changes_shape:
            reshape = ReshapeParameters(
                unpack_node.name + '_reshape',
                old_shape=Dim.unnamed(unpack_node.post_slice_shape),
                shape=Dim.unnamed(unpack_node.out_shape))
            G.add_edge(NNEdge(from_node=in_edge.from_node,
                              to_node=reshape,
                              from_idx=in_edge.from_idx))
            for out_edge in out_edges:
                G.add_edge(NNEdge(from_node=reshape,
                                  to_node=out_edge.to_node,
                                  to_idx=out_edge.to_idx))
            if G.quantization:
                # FIX: this used the undefined name ``unpack`` and raised
                # NameError; the removed unpack node's record is reused.
                G.quantization[NodeId(reshape)] = \
                    G.quantization[NodeId(unpack_node)]
        else:
            for out_edge in out_edges:
                G.add_edge(NNEdge(from_node=in_edge.from_node,
                                  to_node=out_edge.to_node,
                                  from_idx=in_edge.from_idx,
                                  to_idx=out_edge.to_idx))
            if G.quantization:
                del G.quantization[NodeId(unpack_node)]

    if set_identity:
        self.set_identity(G)
    return True
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Replace ``relu6(x) * (1/6)`` with a single hard-sigmoid activation.

    A match is a ``ReluActivationParameters`` node with ``upper_bound == 6``
    whose only consumer is a two-input ``MatrixMulParameters`` node, where
    the other input is a ``ConstantInputParameters`` used nowhere else and
    for which ``check_equals`` reports 1/6. The three nodes are removed and
    an ``HSigmoidActivationParameters`` node (``offset=0``) takes their
    place. When the graph is quantized the new node gets the relu's input
    qtypes and the multiply's output qtypes.

    Args:
        G: graph to modify in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the
            end, as the other matchers in this file do.
        **kwargs: accepted but not used here.

    Returns:
        True if at least one replacement was made, otherwise False.
    """
    something_changed = False
    relu6_nodes = [node for node in G.nodes(node_classes=ReluActivationParameters)
                   if node.upper_bound == 6]
    for relu_node in relu6_nodes:
        out_edges = G.out_edges(relu_node)
        if len(out_edges) != 1 or not isinstance(out_edges[0].to_node,
                                                 MatrixMulParameters):
            continue
        mul_node = out_edges[0].to_node
        in_edges = G.in_edges(mul_node)
        if len(in_edges) != 2:
            continue
        # the multiply's other operand must be the constant 1/6
        other_edge = (set(in_edges) - {out_edges[0]}).pop()
        constant_node = other_edge.from_node
        if len(G.out_edges(constant_node)) != 1:
            continue
        if (not isinstance(constant_node, ConstantInputParameters) or
                not check_equals(G, constant_node, 1.0/6.0)):
            continue

        something_changed = True
        activation = HSigmoidActivationParameters(
            G.unique_name(f'{mul_node.name}_hsigmoid'), offset=0)

        # capture the outside connections before the nodes disappear
        in_edges = G.in_edges(relu_node)
        out_edges = G.out_edges(mul_node)
        nodes_to_replace = [relu_node, mul_node, constant_node]
        LOG.info(f'fusing {", ".join(node.name for node in nodes_to_replace)} into HSIGMOID {activation.name}')
        G.remove_all(nodes_to_replace)
        for in_edge in in_edges:
            G.add_edge(NNEdge.clone(in_edge, to_node=activation, to_idx=0))
        for out_edge in out_edges:
            G.add_edge(NNEdge.clone(out_edge, from_node=activation, from_idx=0))

        if G.quantization:
            reluqrec = G.quantization[NodeId(relu_node)]
            mulqrec = G.quantization[NodeId(mul_node)]
            del G.quantization[NodeId(constant_node)]
            pqrec = QRec.copy_ktype(reluqrec,
                                    in_qs=reluqrec.in_qs,
                                    out_qs=mulqrec.out_qs)
            G.quantization[NodeId(activation)] = pqrec

    # FIX: the set_identity flag was accepted but silently ignored.
    if set_identity:
        self.set_identity(G)
    return something_changed
def match(self, G: GraphView, set_identity: bool = True):
    """Insert explicit transpose nodes for per-edge ``transpose_in``/``transpose_out``.

    ``transpose_in`` and ``transpose_out`` are indexed by edge index; edges
    whose index is out of range or whose entry is ``None`` are left alone.
    For the others a ``TransposeParameters`` node is inserted on the edge,
    dimension hints are moved onto it where available, and the node's
    quantization is copied to it with ``G.quantization.copy_to_node``.

    Args:
        G: graph to modify in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.

    Returns:
        True if at least one transpose node was inserted, otherwise False.
    """
    # get a list of all the nodes that are transposable but not transposes
    # Need to do this first to avoid mutating it when doing the modifications
    transposable = [
        n for n in G.nodes()
        if isinstance(n, Transposable) and not isinstance(n, TransposeParameters)
    ]

    has_modified_graph = False
    for node in transposable:
        if node.transpose_in:
            for idx, edge in enumerate(G.in_edges(node.name)):
                in_range = edge.to_idx < len(node.transpose_in)
                trans = node.transpose_in[edge.to_idx] if in_range else None
                if trans is None:
                    continue
                has_modified_graph = True
                tin = TransposeParameters("%s_TIN_%s" % (node.name, idx),
                                          transpose=trans)
                if node.in_dims_hint and node.in_dims_hint[edge.to_idx]:
                    hint_before = node.in_dims_hint[edge.to_idx]
                    hint_after = apply_reverse_transpose_to_hint(hint_before, trans)
                    tin.in_dims_hint = [hint_before.copy()]
                    tin.out_dims_hint = [hint_after.copy()]
                    node.in_dims_hint[edge.to_idx] = hint_after
                if G.quantization:
                    G.quantization.copy_to_node(node, tin)
                G.insert_node(tin,
                              edge.from_node.name,
                              edge.to_node.name,
                              from_idx=edge.from_idx,
                              to_idx=edge.to_idx,
                              edge_class=NNEdge)
            node.transpose_in = None

        if node.transpose_out:
            for idx, edge in enumerate(G.out_edges(node.name)):
                in_range = edge.from_idx < len(node.transpose_out)
                trans = node.transpose_out[edge.from_idx] if in_range else None
                if trans is None:
                    continue
                has_modified_graph = True
                tout = TransposeParameters("%s_TOUT_%s" % (node.name, idx),
                                           transpose=trans)
                if node.out_dims_hint:
                    hint_after = node.out_dims_hint[edge.from_idx]
                    hint_before = apply_reverse_transpose_to_hint(hint_after, trans)
                    tout.in_dims_hint = [hint_before.copy()]
                    tout.out_dims_hint = [hint_after.copy()]
                    node.out_dims_hint[edge.from_idx] = hint_before
                if G.quantization:
                    G.quantization.copy_to_node(node, tout)
                G.insert_node(tout,
                              edge.from_node.name,
                              edge.to_node.name,
                              from_idx=edge.from_idx,
                              to_idx=edge.to_idx,
                              edge_class=NNEdge)
            node.transpose_out = None

    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse fully connected layers with the activation that follows them.

    The set of accepted activations depends on ``kwargs['group_identity']``:
    ``VALID_ACTIVATIONS_POW2`` for ``'pow2_match_group'`` and
    ``VALID_ACTIVATIONS_SQ8`` otherwise. Chains of at least two nodes
    returned by ``self.get_node_list`` are replaced by a
    ``LinearFusionParameters`` node; quantization records, if any, are
    moved into the fusion.

    Args:
        G: graph to modify in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: only ``group_identity`` is read.

    Returns:
        True if at least one fusion was created, otherwise False.
    """
    if kwargs.get('group_identity') == 'pow2_match_group':
        valid_activations = VALID_ACTIVATIONS_POW2
    else:
        valid_activations = VALID_ACTIVATIONS_SQ8

    has_modified_graph = False
    fc_nodes = [n for n in G.nodes() if isinstance(n, FcParameters)]
    for fc_node in fc_nodes:
        node_list = self.get_node_list(G, fc_node, valid_activations)
        if node_list is None or len(node_list.order) < 2:
            continue

        LOG.info("fusing nodes %s", ",".join(
            (member.name for member in node_list.order)))
        has_modified_graph = True

        subgraph = GraphView()
        last_node = None
        for member in node_list.order:
            if last_node is not None:
                subgraph.add_edge(NNEdge(from_node=last_node, to_node=member))
            last_node = member

        input_mapping = [[(node_list.linear, idx)] for idx in range(3)]
        output_mapping = [(last_node, 0)]
        pnode = LinearFusionParameters(node_list.linear.name + '_fusion',
                                       fusion_type=node_list.fusion_type,
                                       subgraph=subgraph,
                                       input_mapping=input_mapping,
                                       output_mapping=output_mapping)

        if G.quantization:
            # TODO - stats
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                first_rec, last_rec = qrecs[0], qrecs[-1]
                prec = QRec.copy_ktype(first_rec,
                                       in_qs=first_rec.in_qs,
                                       out_qs=last_rec.out_qs)
                for member in pnode.contained_nodes():
                    G.quantization.move_to_fusion(member, pnode)
                G.quantization[NodeId(pnode)] = prec

        in_edges = G.in_edges(node_list.linear.name)
        out_edges = G.out_edges(last_node.name)
        for member in node_list.order:
            G.remove(member)
        for edge in in_edges:
            G.add_edge(NNEdge(edge.from_node, pnode,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(NNEdge(pnode, edge.to_node,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))

    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Remove a concat that is immediately undone by a split.

    A pair qualifies when the split's single input comes from a concat with
    one out edge, neither side carries a transpose, both use the same axis,
    and the split output sizes along that axis equal the concat input sizes
    one for one. Each concat input is then wired straight to the consumers
    of the split output with the same index.

    Args:
        G: graph to modify in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted but not used here.

    Returns:
        True if at least one pair was removed, otherwise False.
    """
    has_modified_graph = False
    split_nodes = {n for n in G.nodes() if isinstance(n, SplitParameters)}
    for split_node in split_nodes:
        in_edges = G.in_edges(split_node.name)
        if len(in_edges) > 1:
            continue
        concat_node = in_edges[0].from_node
        if not isinstance(concat_node, ConcatParameters):
            continue
        if len(G.out_edges(concat_node.name)) > 1:
            continue
        if (concat_node.transpose_out or split_node.transpose_in
                or concat_node.axis != split_node.axis):
            continue

        axis = concat_node.axis
        split_out_sizes = [shape[axis] for shape in split_node.out_shapes]
        if len(split_out_sizes) != len(concat_node.in_dims):
            continue
        # lengths are equal here, so pairing them up covers every index
        if any(size != in_dim.shape[axis]
               for size, in_dim in zip(split_out_sizes, concat_node.in_dims)):
            continue

        has_modified_graph = True
        LOG.info("removing unnecessary concat/split pair %s/%s",
                 concat_node.name, split_node.name)
        concat_in_edges = G.indexed_in_edges(concat_node.name)
        split_out_edges = G.indexed_out_edges(split_node.name)
        G.remove(split_node)
        G.remove(concat_node)
        for idx, source_edge in enumerate(concat_in_edges):
            for out_edge in split_out_edges[idx]:
                G.add_edge(NNEdge(from_node=source_edge.from_node,
                                  from_idx=source_edge.from_idx,
                                  to_node=out_edge.to_node,
                                  to_idx=out_edge.to_idx))

    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def match(self, G: GraphView, set_identity: bool = True):
    """Absorb a reverse node feeding an RNN into the RNN's ``revert`` flag.

    Only reverse nodes with exactly one out edge leading to input 0 of an
    ``RNNBaseParameters`` node are handled. The reverse is removed and its
    inputs are connected to the RNN. Reverse nodes hanging off output 0 of
    that RNN are removed as well, with their consumers reconnected to the
    RNN.

    Args:
        G: graph to modify in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.

    Returns:
        True if at least one reverse was fused, otherwise False.
    """
    # Only works for reverses connected to one RNN node
    reverse_nodes = set()
    for candidate in G.nodes():
        if not isinstance(candidate, ReverseParameters):
            continue
        if len(G.out_edges(candidate.name)) != 1:
            continue
        if isinstance(G.out_edges(candidate.name)[0].to_node, RNNBaseParameters):
            reverse_nodes.add(candidate)

    has_modified_graph = False
    for reverse_node in reverse_nodes:
        in_edges = G.in_edges(reverse_node.name)
        rnn_edge = G.out_edges(reverse_node.name)[0]
        if rnn_edge.to_idx != 0:
            LOG.warning("reverse on rnn input %s", rnn_edge.to_idx)
            continue

        rnn_node = rnn_edge.to_node
        assert not rnn_node.revert, "RNN node is already reversed!"
        rnn_node.revert = True
        LOG.info("fusing reverses into node %s", rnn_node.name)
        has_modified_graph = True

        G.remove(reverse_node)
        for edge in in_edges:
            G.add_edge(NNEdge(edge.from_node, rnn_node,
                              from_idx=edge.from_idx,
                              to_idx=rnn_edge.to_idx))

        # drop reverses on the RNN output as well
        for edge in G.out_edges(rnn_node.name):
            if not isinstance(edge.to_node, ReverseParameters):
                continue
            if edge.from_idx != 0:
                LOG.warning("reverse on rnn output %s", edge.from_idx)
                continue
            rev_edges = G.out_edges(edge.to_node.name)
            G.remove(edge.to_node)
            for rev_edge in rev_edges:
                G.add_edge(NNEdge(edge.from_node, rev_edge.to_node,
                                  from_idx=edge.from_idx,
                                  to_idx=rev_edge.to_idx))

    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Insert a copy node in front of every split selected by ``search_up_for_input``.

    The copy is placed on the split's first in edge. When the graph is
    quantized the split's input 0 record is copied to the new node with
    ``G.quantization.copy_qrec``.

    Args:
        G: graph to modify in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted but not used here.

    Returns:
        True if at least one copy was inserted, otherwise False.
    """
    # Decide on all candidates before touching the graph.
    candidates = []
    for split in G.nodes(node_classes=SplitParameters):
        if search_up_for_input(G, split):
            candidates.append(split)

    has_modified_graph = False
    for node in candidates:
        LOG.info("Insert copy on split input %s", node.name)
        has_modified_graph = True
        cnode = CopyParameters(G.unique_name(f'{node.name}_copy'))
        first_in_edge = G.in_edges(node.name)[0]
        G.insert_node_at_edge(cnode, first_in_edge)
        if G.quantization:
            G.quantization.copy_qrec(node, 'in', 0, cnode)

    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Propagate each RNN's ``'input'`` qtype up the edge that feeds it.

    The input index is looked up in the node's ``INPUT_NAMES`` and the
    matching qtype from the node's quantization record is passed to
    ``propagate_qtype_up``.

    Args:
        G: graph whose quantization is adjusted.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted but not used here.

    Returns:
        ``None`` when the graph has no quantization, ``False`` otherwise.
    """
    if not G.quantization:
        return

    rnn_nodes = [n for n in G.nodes() if isinstance(n, RNNBaseParameters)]
    rnn_qrecs = [G.quantization[NodeId(n)] for n in rnn_nodes]

    for rnn, qrec in zip(rnn_nodes, rnn_qrecs):
        input_idx = rnn.INPUT_NAMES.index('input')
        feeding_edges = [e for e in G.in_edges(rnn.name) if e.to_idx == input_idx]
        propagate_qtype_up(G, qrec.in_qs[input_idx], feeding_edges[0])

    if set_identity:
        self.set_identity(G)
    return False
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Repeatedly replace fragments found by ``self.match_function``.

    After every successful replacement the search restarts from scratch,
    until a full pass over the matches replaces nothing.
    ``self.replace_function`` may raise ``DontReplaceError`` to leave a
    fragment untouched.

    Args:
        G: graph to modify in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted but not used here.

    Returns:
        True if any fragment was removed or replaced, otherwise False.

    Raises:
        TypeError: when ``self.replace_function`` returns a replacement that
            is neither ``None`` nor a ``Node``.
    """
    has_modified_graph = False
    keep_searching = True
    while keep_searching:
        keep_searching = False
        for subgraph in self.match_function(G):
            # TODO - Save in and out edges here since the replace function
            # may modify the subgraph
            in_edges = []
            for input_node in subgraph.inputs():
                in_edges.extend(G.in_edges(input_node.name))
            out_edges = []
            for output_node in subgraph.outputs():
                out_edges.extend(G.out_edges(output_node.name))

            try:
                outcome = self.replace_function(G, subgraph)
                replacement, edge_in_mapping, edge_out_mapping = outcome
                if replacement is None:
                    G.remove_fragment(subgraph)
                elif isinstance(replacement, Node):
                    # use saved in and out edges
                    G.replace_fragment(subgraph,
                                       replacement,
                                       frag_in_edges=in_edges,
                                       frag_out_edges=out_edges,
                                       edge_in_mapping=edge_in_mapping,
                                       edge_out_mapping=edge_out_mapping)
                else:
                    raise TypeError(
                        "unexcepted return value from replace_function")
                has_modified_graph = True
                keep_searching = True
                break
            except DontReplaceError:
                pass

    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def match(self, G: GraphView, set_identity: bool = True):
    """Make every concat input use the concat's output qtype.

    Nothing happens unless the graph is quantized and every concat record
    is a ``MultQuantizationRecord``. For each in edge whose recorded input
    qtype differs from the concat's output qtype, the output qtype is
    pushed up that edge with ``propagate_qtype_up``.

    Args:
        G: graph whose quantization is adjusted.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
    """
    if not G.quantization:
        return

    concat_nodes = [n for n in G.nodes() if isinstance(n, ConcatParameters)]
    concat_qrecs = [G.quantization[NodeId(n)] for n in concat_nodes]
    for rec in concat_qrecs:
        if not isinstance(rec, MultQuantizationRecord):
            return

    for concat, qrec in zip(concat_nodes, concat_qrecs):
        target_q = qrec.out_qs[0]
        for edge in G.in_edges(concat.name):
            if qrec.in_qs[edge.to_idx] != target_q:
                propagate_qtype_up(G, target_q, edge)

    if set_identity:
        self.set_identity(G)
def match(self, G: GraphView, set_identity: bool = True):
    """Remove relu nodes that ``find_redundant_relus`` reports as redundant.

    Every graph input is given a status tuple. For a constant input with a
    value whose minimum is non-negative the status is ``(True, max)``; for
    everything else it is ``(False, False)``. When the graph has quantized
    parameters the constant is dequantized first. The status is stored per
    out edge in ``visited_edges`` and the search continues from the edge's
    target node.

    Args:
        G: graph to modify in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.

    Returns:
        True if at least one relu was removed, otherwise False.
    """
    visited_edges = {}
    nodes_to_remove = []
    has_modified_graph = False
    for node in G.inputs():
        # FIX: start every input from "nothing known". Some branches used to
        # leave ``status`` unassigned, so it was either unbound
        # (UnboundLocalError) or silently inherited from the previous input.
        status = (False, False)
        # check if constantinput. if is then check if positive and check max value
        if isinstance(node, ConstantInputParameters) and node.value is not None:
            if G.has_quantized_parameters:
                qrec = G.quantization[NodeId(node)]
                qtype = qrec.out_qs[0]
                if hasattr(qtype, 'wrapped'):
                    qtype = qtype.wrapped
                val = qtype.dequantize(node.value)
            else:
                val = node.value
            if val.min() >= 0:
                status = (True, val.max())
        for edge in G.out_edges(node.name):
            visited_edges[edge] = status
            nodes_to_remove += find_redundant_relus(
                G, edge.to_node, visited_edges)

    for node in nodes_to_remove:
        has_modified_graph = True
        # Only relus so only one in edge
        in_edge = G.in_edges(node.name)[0]
        for edge in G.out_edges(node.name):
            G.add_edge(NNEdge(from_node=in_edge.from_node,
                              from_idx=in_edge.from_idx,
                              to_node=edge.to_node,
                              to_idx=edge.to_idx))
        G.remove(node)

    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def match(self, G: GraphView, set_identity: bool = True):
    """Move activations to the location found by ``self.find_home_for_activation``.

    Activations whose first in edge already comes from a ``VALID_FUSIONS``
    node are ignored. All moves are planned first and carried out
    afterwards with ``self.move_activation``.

    Args:
        G: graph to modify in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
    """
    all_activations = [n for n in G.nodes() if isinstance(n, ActivationParameters)]
    # kept lazy: each activation is tested just before it is examined
    unfused = (
        act for act in all_activations
        if not isinstance(G.in_edges(act.name)[0].from_node, VALID_FUSIONS)
    )

    planned_moves = []
    for activation in unfused:
        try:
            edges = list(self.find_home_for_activation(G, activation))
            LOG.info("Activation %s can be moved", activation.name)
            planned_moves.append({'activation': activation, 'edges': edges})
        except LocationNotFoundError:
            LOG.info("Activation %s cannot be moved", activation.name)

    for move in planned_moves:
        self.move_activation(G, move['activation'], move['edges'])

    if set_identity:
        self.set_identity(G)
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Hoist an operation duplicated on every split output to the split input.

    ``self.moveable_same_operation_edges`` supplies the out edges carrying
    the duplicated operation. The first of them in name order is kept and
    reinserted on the split's single in edge; every node directly after the
    split is removed and its consumers are attached to the split itself.
    When the graph is quantized, records of deleted duplicates are dropped
    and the graph is requantized with ``NewQuantizer``.

    Args:
        G: graph to modify in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted but not used here.

    Returns:
        True if at least one split was rewritten, otherwise False.
    """
    has_modified_graph = False
    for node in G.nodes(node_classes=SplitParameters):
        same_op_edges = self.moveable_same_operation_edges(G, node)
        if not same_op_edges:
            continue
        has_modified_graph = True

        in_edges = G.in_edges(node.name)
        assert len(in_edges) == 1
        # sort by name to ensure that operation is repeatable
        same_op_edges.sort(key=lambda x: x.to_node.name)
        keep_node = same_op_edges[0].to_node
        LOG.info('split node %s has duplicate operations on its out edges',
                 node.name)
        LOG.info('moving %s before split node %s', keep_node.name, node.name)

        for edge in G.out_edges(node.name):
            follower = edge.to_node
            follower_out_edges = G.out_edges(follower.name)
            G.remove(follower)
            if follower != keep_node:
                LOG.info('deleting duplicate node %s', follower.name)
                if G.quantization:
                    nid = NodeId(follower)
                    if nid in G.quantization:
                        del G.quantization[nid]
            for out_edge in follower_out_edges:
                G.add_edge(NNEdge(from_node=node,
                                  from_idx=edge.from_idx,
                                  to_node=out_edge.to_node,
                                  to_idx=out_edge.to_idx))

        G.insert_node_at_edge(keep_node, in_edges[0], edge_class=NNEdge)
        if G.quantization:
            quantizer = NewQuantizer.from_quantized_graph(G)
            quantizer.quantize()

    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Remove copy nodes that are not required.

    A copy with at most one out edge is kept only when both ``search_down``
    from its out edge and ``search_up`` from its first in edge succeed for
    the listed node classes, passing through reshapes, no-ops and
    transposes that do nothing. Every other such copy is removed with
    ``G.remove_and_reconnect`` and its quantization record, if present, is
    deleted.

    Args:
        G: graph to modify in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted but not used here.

    Returns:
        True if at least one copy was removed, otherwise False.
    """
    def passes_noop_transpose(G, node):
        return isinstance(node, TransposeParameters) and node.does_nothing

    nodes_to_remove = []
    for node in G.nodes(node_classes=CopyParameters):
        out_edges = G.out_edges(node)
        if len(out_edges) > 1:
            continue
        copy_is_needed = (
            search_down(G,
                        out_edges[0],
                        (OutputParameters, InputParameters,
                         ConstantInputParameters, SplitParameters,
                         ConcatParameters),
                        can_pass=(ReshapeParameters, NoOPParameters),
                        can_pass_fn=passes_noop_transpose,
                        follow_multi=True)
            and search_up(G,
                          G.in_edges(node)[0],
                          (InputParameters, OutputParameters,
                           ConstantInputParameters, SplitParameters,
                           ConcatParameters),
                          can_pass=(ReshapeParameters, NoOPParameters),
                          can_pass_fn=passes_noop_transpose,
                          follow_multi=True)
        )
        if not copy_is_needed:
            nodes_to_remove.append(node)

    has_modified_graph = False
    for node in nodes_to_remove:
        LOG.info("remove redundant copy %s", node.name)
        has_modified_graph = True
        G.remove_and_reconnect(node, edge_class=NNEdge)
        if G.quantization:
            nid = NodeId(node)
            if nid in G.quantization:
                del G.quantization[nid]

    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Move nodes selected by ``self.execute_tests`` to the location ``self.find_home_for_node`` finds.

    Nodes whose first in edge already comes from a ``self.ValidFusions``
    node are ignored. All moves are planned first and then carried out
    with ``self.move_node``.

    Args:
        G: graph to modify in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted but not used here.

    Returns:
        True if at least one node was moved, otherwise False.
    """
    selected = [n for n in G.nodes()
                if self.execute_tests(G, self.ValidNodes, n)]
    # kept lazy: each node is tested just before it is examined
    target_nodes = (
        n for n in selected
        if not isinstance(G.in_edges(n.name)[0].from_node, self.ValidFusions)
    )

    planned_moves = []
    for node in target_nodes:
        try:
            edges = list(self.find_home_for_node(G, node))
            LOG.info("Node %s can be moved", node.name)
            planned_moves.append({'node': node, 'edges': edges})
        except LocationNotFoundError:
            LOG.info("Node %s cannot be moved", node.name)

    has_modified_graph = False
    for move in planned_moves:
        has_modified_graph = True
        self.move_node(G, move['node'], move['edges'])

    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Remove relu nodes that ``find_redundant_relus`` reports as redundant.

    Every graph input gets a status tuple: ``(True, max)`` for a constant
    input whose dequantized value (``dqvalue``) has a non-negative minimum,
    ``(False, False)`` for everything else. The status is stored per out
    edge in ``visited_edges`` and the search continues from the edge's
    target node.

    Args:
        G: graph to modify in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted but not used here.

    Returns:
        True if at least one relu was removed, otherwise False.
    """
    visited_edges = {}
    nodes_to_remove = []
    for node in G.inputs():
        # check if constantinput. if is then check if positive and check max value
        status = (False, False)
        if isinstance(node, ConstantInputParameters) and node.value is not None:
            val = node.dqvalue
            if np.min(val) >= 0:
                status = (True, np.max(val))
        for edge in G.out_edges(node.name):
            visited_edges[edge] = status
            nodes_to_remove += find_redundant_relus(G, edge.to_node,
                                                    visited_edges)

    has_modified_graph = False
    for relu in nodes_to_remove:
        has_modified_graph = True
        # Only relus so only one in edge
        LOG.info("removing redundant relu %s", relu.name)
        feed = G.in_edges(relu.name)[0]
        consumers = G.out_edges(relu.name)
        G.remove(relu)
        for edge in consumers:
            G.add_edge(NNEdge(from_node=feed.from_node,
                              from_idx=feed.from_idx,
                              to_node=edge.to_node,
                              to_idx=edge.to_idx))

    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse convolutions with the pooling/activation nodes that follow them.

    The accepted activations depend on ``kwargs['group_identity']``:
    ``VALID_ACTIVATIONS_POW2`` for ``'pow2_match_group'`` and
    ``VALID_ACTIVATIONS_SQ8`` otherwise. For ``conv_active_pool`` chains
    with average pooling only the first two nodes are fused; for
    ``conv_pool_active`` chains with average pooling the fusion is skipped
    unless the activation is ``"relu"``. The chain is replaced by a
    ``ConvFusionParameters`` node carrying copies of the conv input dims
    and the last node's output dims.

    Args:
        G: graph to modify in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: only ``group_identity`` is read.

    Returns:
        True if at least one fusion was created, otherwise False.
    """
    if kwargs.get('group_identity') == 'pow2_match_group':
        valid_activations = VALID_ACTIVATIONS_POW2
    else:
        valid_activations = VALID_ACTIVATIONS_SQ8

    has_modified_graph = False
    conv_nodes = [n for n in G.nodes() if isinstance(n, Conv2DParameters)]
    for conv_node in conv_nodes:
        node_list = self.get_node_list(G, conv_node, valid_activations)
        if node_list is None or len(node_list.order) < 2:
            continue

        fusion_type = node_list.fusion_type
        if fusion_type == 'conv_active_pool':
            if node_list.pool.pool_type == "average":
                # drop the pool from the chain
                node_list.order = node_list.order[:2]
                node_list.pool = None
        elif fusion_type == 'conv_pool_active':
            # NOTE: This is only for old POW2 kernels - SQ8 can handle this
            if (node_list.pool.pool_type == "average"
                    and node_list.active.activation != "relu"):
                continue

        LOG.info("fusing nodes %s", ",".join(
            (member.name for member in node_list.order)))
        has_modified_graph = True

        subgraph = GraphView()
        last_node = None
        for member in node_list.order:
            if last_node is not None:
                subgraph.add_edge(NNEdge(from_node=last_node, to_node=member))
            last_node = member

        input_mapping = [[(node_list.conv, idx)] for idx in range(3)]
        output_mapping = [(last_node, 0)]
        pnode = ConvFusionParameters(
            node_list.conv.name + '_fusion',
            fusion_type=node_list.fusion_type,
            subgraph=subgraph,
            in_dims_hint=node_list.conv.in_dims_hint,
            out_dims_hint=node_list.conv.out_dims_hint,
            in_dims=deepcopy(node_list.conv.in_dims),
            out_dims=deepcopy(node_list.order[-1].out_dims),
            input_mapping=input_mapping,
            output_mapping=output_mapping)

        if G.quantization:
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                # TODO - stats
                first_rec, last_rec = qrecs[0], qrecs[-1]
                prec = QRec.copy_ktype(first_rec,
                                       in_qs=deepcopy(first_rec.in_qs),
                                       out_qs=deepcopy(last_rec.out_qs))
                for member in pnode.contained_nodes():
                    G.quantization.move_to_fusion(member, pnode)
                G.quantization[NodeId(pnode)] = prec

        in_edges = G.in_edges(node_list.conv.name)
        out_edges = G.out_edges(last_node.name)
        for member in node_list.order:
            G.remove(member)
        for edge in in_edges:
            G.add_edge(NNEdge(edge.from_node, pnode,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(NNEdge(pnode, edge.to_node,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))

    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def match(self, G: GraphView, set_identity: bool = True) -> bool:
    """Fold Pad nodes into the padding of the filter nodes that consume them.

    For each outgoing edge of every ``PadParameters`` node,
    ``self.find_conv`` is asked for a filter-like node and a 1D flag. When
    the padding only touches the axes hinted as ``'h'``/``'w'`` and all pad
    values are zero, the padding is written onto the filter
    (``padding`` / ``pad_type = "zero"``) and the pad's inputs are wired
    straight to the edge's destination. The pad node is removed only if
    every one of its outgoing edges could be handled.

    Args:
        G: graph that is scanned and rewritten in place.
        set_identity: when true, ``self.set_identity(G)`` is called before
            returning.

    Returns:
        True if at least one pad was folded into a filter.

    Raises:
        ValueError: a filter found by ``self.find_conv`` has no input
            dimension hint.
    """
    changed = False
    pad_nodes = [cand for cand in G.nodes() if isinstance(cand, PadParameters)]
    for pad_params in pad_nodes:
        pad_in_edges = G.in_edges(pad_params.name)
        pad_out_edges = G.out_edges(pad_params.name)
        # Becomes True as soon as one consumer cannot absorb the padding.
        keep_pad = False

        for pad_out_edge in pad_out_edges:
            filter_like_node, is_1d = self.find_conv(G, pad_out_edge.to_node)
            if not filter_like_node:
                keep_pad = True
                continue

            hints = filter_like_node.in_dims_hint
            if not hints or not hints[0]:
                raise ValueError(
                    f"filter {filter_like_node.name} doesn't have a input hint"
                )
            in_hint = filter_like_node.in_dims_hint[0]

            expected_len = 2 if is_1d else 3
            if len(pad_params.padding) != expected_len:
                LOG.warning(
                    "pad node %s is applied to 1d convolution but has length %s"
                    if is_1d else
                    "pad node %s is applied to 2d convolution but has length %s",
                    pad_params.name, len(pad_params.padding))
                keep_pad = True
                continue
            if is_1d:
                # Insert an unpadded middle axis so the padding has 3 entries.
                expanded_padding = [
                    pad_params.padding[0], (0, 0), pad_params.padding[1]
                ]
            else:
                expanded_padding = pad_params.padding

            # Axis name -> (before, after) for every axis that is padded.
            hinted_pad = {
                in_hint[idx]: amount
                for idx, amount in enumerate(expanded_padding)
                if sum(amount) > 0
            }
            key_set = set(hinted_pad) - {'h', 'w'}
            if key_set:
                keep_pad = True
                LOG.error(
                    "node %s has padding on axes %s and cannot be fused with filter %s",
                    pad_params.name, key_set, filter_like_node.name)
                continue
            if any(pval != 0 for val in pad_params.pad_vals for pval in val):
                keep_pad = True
                LOG.error(
                    "node %s has non zero pad values and cannot be fused with filter %s",
                    pad_params.name, filter_like_node.name)
                continue

            LOG.info("adding padding from: %s to %s filter: %s",
                     pad_params.name, "1D" if is_1d else "2D",
                     filter_like_node.name)
            for axis_name in ('h', 'w'):
                hinted_pad.setdefault(axis_name, (0, 0))
            filter_like_node.padding = PadDim(*hinted_pad['h'],
                                              *hinted_pad['w'])
            filter_like_node.pad_type = "zero"
            changed = True

            G.remove_edge(pad_out_edge)
            if is_1d:
                # In the 1D case the edge's destination is treated as a
                # reshape whose shapes must lose the padding as well.
                reshape_node = pad_out_edge.to_node
                reshape_node.old_shape = self.remove_padding(
                    reshape_node.old_shape, pad_params.padding)
                reshape_node.shape = self.remove_padding(
                    reshape_node.shape, expanded_padding)
            # Bypass the pad: connect its producers to the old destination.
            for in_edge in pad_in_edges:
                G.add_edge(NNEdge(from_node=in_edge.from_node,
                                  to_node=pad_out_edge.to_node,
                                  from_idx=in_edge.from_idx,
                                  to_idx=pad_out_edge.to_idx))

        if keep_pad:
            continue
        G.remove(pad_params)
        if G.quantization:
            G.quantization.remove_node(pad_params)

    if set_identity:
        self.set_identity(G)

    return changed
def match(self, G: GraphView, set_identity: bool = True):
    """Fuse a Pad node with the add (and optional activation) that follows it.

    For every ``PadParameters`` node, ``self.get_node_list`` supplies the
    ``pad`` / ``add`` / ``active`` nodes to fuse; results with fewer than two
    nodes are skipped. The nodes are linked in a new ``GraphView`` and
    replaced in ``G`` by a ``PaddedAddFusionParameters`` node named
    ``"PADDED_" + <add name>``.

    Args:
        G: graph that is scanned and rewritten in place.
        set_identity: when true, ``self.set_identity(G)`` is called before
            returning.

    Returns:
        True if at least one fusion node was created, otherwise False.
    """
    # (accepted record classes) -> class used for the fusion's own record.
    record_kinds = (
        ((SymmetricQuantizationRecord,
          SymmetricScalableFilterQuantizationRecord),
         SymmetricQuantizationRecord),
        ((MultQuantizationRecord,
          MultScalableFilterQuantizationRecord),
         MultQuantizationRecord),
        ((Float32QuantizationRecord,
          Float32ScalableFilterQuantizationRecord),
         Float32QuantizationRecord),
    )

    fused_any = False
    pad_nodes = [cand for cand in G.nodes() if isinstance(cand, PadParameters)]
    for pad_node in pad_nodes:
        node_list = self.get_node_list(G, pad_node)
        if node_list is None or len(node_list.order) < 2:
            continue
        LOG.info("fusing nodes %s",
                 ",".join(member.name for member in node_list.order))
        fused_any = True

        pad, add, active = node_list.pad, node_list.add, node_list.active

        subgraph = GraphView()
        # Which input of the add receives the padded tensor.
        padded_input_idx = G.out_edges(pad.name)[0].to_idx
        subgraph.add_edge(
            NNEdge(from_node=pad, to_node=add, to_idx=padded_input_idx))
        last_node = add
        add.force_quantized_index = 0
        if active:
            subgraph.add_edge(NNEdge(from_node=add, to_node=active))
            last_node = active

        pad_is_first_input = padded_input_idx == 0
        if pad_is_first_input:
            input_mapping = [[(pad, 0)], [(add, 1)]]
        else:
            input_mapping = [[(add, 0)], [(pad, 1)]]
        output_mapping = [(last_node, 0)]
        pnode = PaddedAddFusionParameters(
            "PADDED_" + add.name,
            fusion_type=node_list.fusion_type,
            subgraph=subgraph,
            input_mapping=input_mapping,
            output_mapping=output_mapping)

        if G.quantization:
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                # The fusion record keeps the first record's inputs and the
                # last record's outputs; it stays None for unknown kinds.
                prec = None
                for accepted, record_cls in record_kinds:
                    if isinstance(qrecs[0], accepted):
                        prec = record_cls(in_qs=qrecs[0].in_qs,
                                          out_qs=qrecs[-1].out_qs)
                        break
                for contained in pnode.contained_nodes():
                    G.quantization.move_to_fusion(contained, pnode)
                G.quantization[NodeId(pnode)] = prec

        # Collect the boundary edges before the fused nodes are removed.
        if pad_is_first_input:
            incoming = (G.in_edges(pad.name)
                        + G.indexed_in_edges(add.name)[1:])
        else:
            incoming = (G.indexed_in_edges(add.name)[:1]
                        + G.in_edges(pad.name))
        outgoing = G.out_edges(last_node.name)
        for member in node_list.order:
            G.remove(member)
        for edge in incoming:
            G.add_edge(NNEdge(edge.from_node, pnode,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in outgoing:
            G.add_edge(NNEdge(pnode, edge.to_node,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))

    if set_identity:
        self.set_identity(G)

    return fused_any
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Replace groups of Gather nodes reading one output with a single Split.

    Gather nodes are grouped by the (node, output index) feeding their first
    input. A group is converted only when all of these hold:

    * every gather uses the same ``axis``;
    * every gather selects exactly one index (the split slices are built as
      ``(index, index + 1, 1)``);
    * the selected indices are all different and there are exactly as many
      of them as the gathered axis is long (no duplicates, no gaps).

    Converted gathers are removed and a ``SplitParameters`` node named
    ``<origin name>_split`` is wired in, its outputs ordered by index.

    Args:
        G: graph that is scanned and rewritten in place.
        set_identity: when true, ``self.set_identity(G)`` is called before
            returning.
        **kwargs: accepted for interface compatibility; not read here.

    Returns:
        True if at least one group of gathers was replaced by a split.
    """
    has_modified_graph = False
    gathers_by_origin = {}
    for gather in [node for node in G.nodes()
                   if isinstance(node, GatherParameters)]:
        in_edge = G.in_edges(gather.name)[0]
        gathers_by_origin.setdefault(
            (in_edge.from_node, in_edge.from_idx), []).append(gather)

    for (origin_node, origin_idx), gathers in gathers_by_origin.items():
        # This is too difficult to handle if there are multiple slices.
        # The first gather is checked too, not only the ones after it.
        axis = gathers[0].axis
        if not all(gather.axis == axis and len(gather.indices.shape) <= 1
                   for gather in gathers):
            continue

        # Normalise scalar and 1-D index containers to plain int lists.
        index_lists = [
            [int(gather.indices)] if len(gather.indices.shape) == 0
            else [int(elem) for elem in gather.indices]
            for gather in gathers
        ]
        # Each split output is one index wide, so a gather selecting zero or
        # several indices cannot be expressed. Reject the group up front,
        # before any node has been removed from the graph.
        if any(len(index_list) != 1 for index_list in index_lists):
            continue

        # sort the gathers by the index they select
        ordered = sorted(zip(index_lists, gathers),
                         key=lambda pair: pair[0][0])
        indices = [index_list[0] for index_list, _ in ordered]
        gathers = [gather for _, gather in ordered]

        # All the indices must be independent and cover the whole axis (this
        # could be relaxed but then needs to handle gaps).
        in_shape = origin_node.out_dims[origin_idx].shape
        in_shape_without_axis = in_shape[:axis] + in_shape[axis + 1:]
        if len(set(indices)) != len(indices) or \
                len(indices) != in_shape[axis]:
            continue

        # good for a split
        LOG.info("gathers from %s[%s] converted to a split",
                 origin_node.name, origin_idx)
        splits = []
        shapes = []
        out_edges = []
        for index, gather in zip(indices, gathers):
            splits.append([(index, index + 1, 1)])
            # NOTE(review): this shape drops the gathered axis, which matches
            # scalar indices; confirm it is also what is wanted for a
            # one-element 1-D index.
            shapes.append(in_shape_without_axis)
            out_edges.append(G.out_edges(gather.name))
            G.remove(gather)

        params = SplitParameters("%s_split" % origin_node.name,
                                 act_slices=splits,
                                 out_shapes=shapes,
                                 axis=axis)
        if axis != 0:
            # Permutation that moves the gathered axis to the front. The
            # tail starts at axis + 1 so that ``axis`` is listed only once
            # and the inverse lookup below cannot fail.
            trans = ([axis] + list(range(0, axis))
                     + list(range(axis + 1, len(in_shape))))
            # NOTE(review): transpose_out is the inverse of ``trans`` over
            # the input rank; confirm this is what SplitParameters expects.
            params.transpose_out = [[
                trans.index(idx) for idx in range(len(trans))
            ]]
            params.transpose_in = [trans]

        for idx, edges in enumerate(out_edges):
            for edge in edges:
                G.add_edge(NNEdge(from_node=params,
                                  to_node=edge.to_node,
                                  from_idx=idx,
                                  to_idx=edge.to_idx))
        G.add_edge(NNEdge(from_node=origin_node,
                          to_node=params,
                          from_idx=origin_idx))
        has_modified_graph = True

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse pooling / global pooling nodes with the activation after them.

    For every ``PoolingParameters`` or ``GlobalPoolingParameters`` node,
    ``self.get_node_list`` supplies the chain to fuse; chains shorter than
    two nodes are skipped. The chain is linked in a new ``GraphView`` and
    replaced in ``G`` by an ``ActivationFusion`` node named
    ``<pool name>_fusion``.

    Args:
        G: graph that is scanned and rewritten in place.
        set_identity: when true, ``self.set_identity(G)`` is called before
            returning.
        **kwargs: only ``group_identity`` is read. ``'pow2_match_group'``
            selects the ``*_POW2`` activation sets; any other value selects
            the ``*_SQ8`` ones.

    Returns:
        True if at least one fusion node was created, otherwise False.
    """
    fused_any = False
    if kwargs.get('group_identity') == 'pow2_match_group':
        valid_activations = VALID_ACTIVATIONS_POW2
        valid_activations_wo_pool = VALID_ACTIVATIONS_POW2_WO_POOL
    else:
        valid_activations = VALID_ACTIVATIONS_SQ8
        valid_activations_wo_pool = VALID_ACTIVATIONS_SQ8_WO_POOL

    pool_classes = (PoolingParameters, GlobalPoolingParameters)
    for pool_node in G.nodes(node_classes=pool_classes):
        node_list = self.get_node_list(G, pool_node, valid_activations,
                                       valid_activations_wo_pool)
        if node_list is None or len(node_list.order) < 2:
            continue
        LOG.info("fusing nodes %s",
                 ",".join(member.name for member in node_list.order))
        fused_any = True

        # Chain the fused nodes one after another inside their own subgraph.
        subgraph = GraphView()
        last_node = None
        for member in node_list.order:
            if last_node is not None:
                subgraph.add_edge(NNEdge(from_node=last_node, to_node=member))
            last_node = member

        pnode = ActivationFusion(node_list.pool.name + '_fusion',
                                 fusion_type=node_list.fusion_type,
                                 subgraph=subgraph,
                                 input_mapping=[[(node_list.pool, 0)]],
                                 output_mapping=[(last_node, 0)])

        if G.quantization:
            # TODO - stats
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                # Fusion record: inputs of the first record, outputs of the last.
                prec = QRec.copy_ktype(qrecs[0],
                                       in_qs=qrecs[0].in_qs,
                                       out_qs=qrecs[-1].out_qs)
                for contained in pnode.contained_nodes():
                    G.quantization.move_to_fusion(contained, pnode)
                    if not isinstance(contained, GlobalPoolingParameters):
                        continue
                    # Global pooling fused with activations need to have only
                    # the activation scale: its output type becomes a copy of
                    # its input type with the dtype forced to int32.
                    fused_qrec = G.quantization[NodeId(pnode, contained)]
                    fused_qrec.out_qs[0] = deepcopy(fused_qrec.in_qs[0])
                    fused_qrec.out_qs[0].dtype = np.int32
                G.quantization[NodeId(pnode)] = prec

        # Capture the boundary edges before the fused nodes are removed.
        incoming = G.in_edges(node_list.pool.name)
        outgoing = G.out_edges(last_node.name)
        for member in node_list.order:
            G.remove(member)
        for edge in incoming:
            G.add_edge(NNEdge(edge.from_node, pnode,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in outgoing:
            G.add_edge(NNEdge(pnode, edge.to_node,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))

    if set_identity:
        self.set_identity(G)

    return fused_any
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Replace groups of StridedSlice nodes reading one output with a Split.

    Slice nodes are grouped by the (node, output index) feeding their first
    input. A group with a single slice is handed to ``self.slice_to_split``.
    For larger groups, the axes on which the slices differ are moved to the
    front (transpose) and, if there are several, merged into one axis
    (reshape), so that a single ``SplitParameters`` node on axis 0 can
    produce every slice. Matching reshape / transpose nodes are added on
    each output to restore the original layout. Groups are skipped when any
    stride is not 1, when all slices are identical, or when
    ``combine_slices`` returns None.

    When ``G.quantization`` is set, the graph is re-dimensioned and
    re-quantized after each replaced group.

    Args:
        G: graph that is scanned and rewritten in place.
        set_identity: when true, ``self.set_identity(G)`` is called before
            returning.
        **kwargs: accepted for interface compatibility; not read here.

    Returns:
        True if at least one multi-slice group was replaced by a split.
    """
    has_modified_graph = False
    slices_by_origin = {}
    for slice_node in [node for node in G.nodes()
                       if isinstance(node, StridedSliceParameters)]:
        in_edge = G.in_edges(slice_node.name)[0]
        slices_by_origin.setdefault(
            (in_edge.from_node, in_edge.from_idx), []).append(slice_node)

    for (origin_node, origin_idx), slice_nodes in slices_by_origin.items():
        # slices[axis] holds that axis' slice from every node in the group.
        slices = list(zip(*[node.act_slice for node in slice_nodes]))
        if len(slice_nodes) == 1:
            # NOTE(review): the result of slice_to_split is ignored, so this
            # path never sets has_modified_graph - confirm that is intended.
            self.slice_to_split(G, slice_nodes, slices)
            continue

        # strides must be one
        if any(sl[2] != 1 for sl_axis in slices for sl in sl_axis):
            continue

        # Axes on which at least one node slices differently from the first.
        diff_axes = [
            idx for idx, elems in enumerate(slices)
            if not all(elems[0] == elem for elem in elems[1:])
        ]
        if not diff_axes:
            # Every slice in the group is identical: there is nothing to
            # split on, and max(diff_axes) below would raise ValueError.
            continue
        not_diff_axes = [
            idx for idx in range(len(slices)) if idx not in diff_axes
        ]
        diff_slices = [
            sl for idx, sl in enumerate(slices) if idx in diff_axes
        ]

        axis_lengths = origin_node.out_dims[origin_idx].shape
        if not_diff_axes and min(not_diff_axes) < max(diff_axes):
            # A common axis sits before a differing one: bring all the
            # differing axes to the front.
            transpose_from = tuple(range(len(slices)))
            transpose_to = tuple(diff_axes + not_diff_axes)
            axis_lengths = [axis_lengths[idx] for idx in transpose_to]
        else:
            transpose_from = transpose_to = None

        diff_axis_lengths = axis_lengths[0:len(diff_axes)]
        diff_slices = combine_slices(diff_axis_lengths, diff_slices,
                                     slice_nodes)
        if diff_slices is None:
            continue

        if len(diff_axes) > 1:
            # Merge the leading differing axes into a single axis.
            reshape_from = axis_lengths
            reshape_to = [np.prod(diff_axis_lengths)] + \
                axis_lengths[len(diff_axes):]
        else:
            reshape_from = None
            reshape_to = slice_nodes[0].in_dims[0].shape
            if transpose_from:
                reshape_to = [reshape_to[idx] for idx in transpose_to]

        sizes, shapes, sorted_nodes = slices_to_sizes(
            diff_slices, axis_lengths[len(diff_axes):])

        name_prefix = sorted_nodes[0].name
        first_in_edge = G.in_edges(sorted_nodes[0].name)[0]
        in_node = first_in_edge.from_node
        in_idx = first_in_edge.from_idx

        if transpose_from:
            params = TransposeParameters(G.unique_name(name_prefix + '_tin'),
                                         transpose=transpose_to)
            G.add_edge(
                NNEdge(from_node=in_node, to_node=params, from_idx=in_idx))
            in_node = params
            in_idx = 0

        if reshape_from:
            params = ReshapeParameters(
                G.unique_name(name_prefix + '_reshape'),
                old_shape=Dim.unnamed(reshape_from),
                shape=Dim.unnamed(reshape_to))
            G.add_edge(
                NNEdge(from_node=in_node, to_node=params, from_idx=in_idx))
            in_node = params
            in_idx = 0

        act_slices, out_shapes, axis = SplitParameters.get_splits(
            reshape_to, 0, splits=sizes)
        split_node = SplitParameters(G.unique_name(name_prefix + '_split'),
                                     act_slices=act_slices,
                                     out_shapes=out_shapes,
                                     axis=axis)
        G.add_edge(
            NNEdge(from_node=in_node, from_idx=in_idx, to_node=split_node))

        sub_names = []
        for idx, node in enumerate(sorted_nodes):
            sub_names.append(node.name)
            out_edges = G.out_edges(node.name)
            G.remove(node)
            # Each consumer gets its own restoring reshape / transpose.
            for out_edge in out_edges:
                params = split_node
                out_idx = idx
                if reshape_from:
                    from_node = params
                    params = ReshapeParameters(
                        G.unique_name(name_prefix + f'_reshape{idx}'),
                        shape=Dim.unnamed(shapes[idx]))
                    G.add_edge(
                        NNEdge(from_node=from_node,
                               to_node=params,
                               from_idx=out_idx))
                    out_idx = 0
                if transpose_from:
                    from_node = params
                    params = TransposeParameters(
                        G.unique_name(name_prefix + f'_tout{idx}'),
                        transpose=reverse_transpose(transpose_to))
                    G.add_edge(
                        NNEdge(from_node=from_node,
                               to_node=params,
                               from_idx=out_idx))
                    out_idx = 0

                G.add_edge(
                    NNEdge(from_node=params,
                           to_node=out_edge.to_node,
                           from_idx=out_idx,
                           to_idx=out_edge.to_idx))

        if G.quantization:
            # Recompute dimensions and quantization for the inserted nodes.
            G.add_dimensions()
            quantizer = NewQuantizer.from_quantized_graph(G)
            quantizer.quantize()
            RemoveUnnecessaryQuantizeOperators().match(G)

        LOG.info(
            f'replaced slice nodes {",".join(sub_names)} with split node {split_node.name}'
        )

        has_modified_graph = True

    if set_identity:
        self.set_identity(G)

    return has_modified_graph