def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Replace each node accepted by ``MatScaleNodeMatch`` with a fusion node.

    Args:
        G: graph to rewrite; modified in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted for interface compatibility; not read here.

    Returns:
        True if at least one node was replaced.
    """
    matcher = GraphMatcher(
        match_function=lambda state, frag: (frag, state['match']))
    matcher.add_node(MatScaleNodeMatch())

    has_modified_graph = False
    for frag, match in matcher.match_graph(G):
        # Capture the edges around the matched node before it is removed.
        input_edges = [
            G.indexed_in_edges(in_node.name)[in_idx]
            for in_node, in_idx in match['inputs']
        ]
        matched_node = list(frag.nodes())[0]
        output_edges = G.out_edges(matched_node.name)

        has_modified_graph = True
        G.remove(matched_node)
        fusion_node = MatScaleFusionParameters(
            "{}_fusion".format(matched_node.name),
            fusion_type=match['type'],
            subgraph=frag,
            input_mapping=[[(matched_node, 0)], [(matched_node, 1)]])
        G.add_node(fusion_node)

        # The captured edge objects are re-targeted and re-added, not cloned.
        for position, in_edge in enumerate(input_edges):
            in_edge.to_node = fusion_node
            in_edge.to_idx = position
            G.add_edge(in_edge)
        for out_edge in output_edges:
            out_edge.from_node = fusion_node
            G.add_edge(out_edge)

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def reverse_matmul(G: GraphView, params):
    """Exchange the first two inputs of ``params`` and wrap it in transposes.

    The first two input edges are swapped, the two matching quantization
    input records (when present) are swapped, a ``(1, 0)`` transpose is
    inserted in front of inputs 0 and 1 and another one after the output.

    Args:
        G: graph containing ``params``; modified in place.
        params: node whose first two inputs are reversed.

    Returns:
        Tuple ``(in_nodes, tout_params)``: the two inserted input transposes
        (input 0 first) and the inserted output transpose.
    """
    # reverse edges
    operand_edges = G.indexed_in_edges(params.name)[0:2:]
    for edge in operand_edges:
        G.remove_edge(edge)
    for swapped_idx, edge in zip((1, 0), operand_edges):
        G.add_edge(
            NNEdge(from_node=edge.from_node,
                   to_node=params,
                   from_idx=edge.from_idx,
                   to_idx=swapped_idx))

    nid = NodeId(params)
    if G.quantization and nid in G.quantization:
        qrec = G.quantization[nid]
        # swap qrecs
        second_q, first_q = qrec.in_qs[1], qrec.in_qs[0]
        qrec.in_qs[0] = second_q
        qrec.in_qs[1] = first_q

    # add transposes
    in_nodes = []
    for idx in (0, 1):
        tin_params = TransposeParameters(
            G.unique_name(f"{params.name}_tin{idx}"), transpose=(1, 0))
        in_nodes.append(tin_params)
        G.insert_node_before(tin_params, params, to_idx=idx,
                             edge_class=NNEdge)
    tout_params = TransposeParameters(
        G.unique_name(f"{params.name}_tout"), transpose=(1, 0))
    G.insert_node_after(params, tout_params)
    return in_nodes, tout_params
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Remove duplicated node strings hanging off the same output.

    Candidates are nodes with exactly one indexed output but more than one
    out edge. ``self.explore`` returns the strings found below a candidate;
    the first is kept, the others are removed and their consumers are
    reconnected to the end of the kept string.

    Args:
        G: graph to rewrite; modified in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted for interface compatibility; not read here.

    Returns:
        True if any duplicate string was removed.
    """
    def forget(candidate_list, visited):
        # A node that has been handled must not be explored again.
        if visited in candidate_list:
            candidate_list.remove(visited)

    modified_graph = False
    candidates = [
        node for node in G.nodes()
        if len(G.indexed_out_edges(node.name)) == 1
        and len(G.out_edges(node.name)) > 1
    ]
    while candidates:
        start = candidates.pop(0)
        strings = self.explore(G, [start])
        if not strings:
            continue
        modified_graph = True

        primary = strings.pop(0)
        for kept in primary:
            forget(candidates, kept)

        dangling_edges = []
        for other in strings:
            # Record the consumers of the duplicate string before removing it.
            dangling_edges.extend(G.out_edges(other[-1].name))
            for duplicate in other:
                forget(candidates, duplicate)
                G.remove(duplicate)
                nid = NodeId(duplicate)
                if G.quantization and nid in G.quantization:
                    del G.quantization[nid]
            removed_names = ",".join(node.name for node in other)
            LOG.info(
                f'removed duplicates from {primary[0].name} {removed_names}')

        pend = primary[-1]
        for edge in dangling_edges:
            G.add_edge(
                NNEdge(from_node=pend,
                       to_node=edge.to_node,
                       to_idx=edge.to_idx))

    if set_identity:
        self.set_identity(G)

    return modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fold a transpose on the second matmul input into the matmul node.

    A ``MatMulOpParameters`` node whose input 1 comes from a
    ``TransposeParameters`` node is replaced by the opposite matmul flavour
    (plain <-> ``MatMulTransposedParameters``) and the transpose is removed.

    Changes from the previous implementation:
      * the node list is snapshotted before the graph is mutated;
      * matmuls with fewer than two connected inputs are skipped instead of
        raising ``IndexError``/``AttributeError``;
      * a transpose that feeds other consumers is left alone, because removing
        it would disconnect those consumers.

    NOTE(review): the permutation of the transpose is not inspected here -
    confirm that only transposes swapping the two matmul axes can reach input 1.

    Args:
        G: graph to rewrite; modified in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted for interface compatibility; not read here.

    Returns:
        True if at least one transpose was folded.
    """
    has_modified_graph = False
    for node in list(G.nodes(node_classes=MatMulOpParameters)):
        in_edges = list(G.indexed_in_edges(node.name))
        if len(in_edges) < 2 or in_edges[1] is None:
            continue
        trans_node = in_edges[1].from_node
        if not isinstance(trans_node, TransposeParameters):
            continue
        # The transpose is deleted below, so it must feed only this matmul.
        if len(G.out_edges(trans_node.name)) != 1:
            continue
        trans_in_edges = list(G.indexed_in_edges(trans_node.name))
        if not trans_in_edges or trans_in_edges[0] is None:
            continue
        in_trans_edge = trans_in_edges[0]

        if isinstance(node, MatMulTransposedParameters):
            new_node = MatMulOpParameters(node.name)
        else:
            new_node = MatMulTransposedParameters(node.name)

        G.replace_node(node.name, new_node)
        G.remove(trans_node)
        # Feed input 1 directly from whatever fed the transpose.
        G.add_edge(
            NNEdge(in_trans_edge.from_node,
                   new_node,
                   from_idx=in_trans_edge.from_idx,
                   to_idx=1))
        has_modified_graph = True

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def match(self, G: GraphView, set_identity: bool = True):
    """Insert copy nodes between split outputs and the concats they feed.

    For every ``SplitParameters`` node, ``self.find_split_concat`` returns one
    bundle of edge sets per split output (or ``None``). For each non-empty
    bundle a ``CopyParameters`` node is created, the first edge of every edge
    set is re-routed to come from the copy, and the copy is fed from the
    corresponding split output.

    Changes from the previous implementation:
      * the quantization record of the copy node is built with *lists* for
        ``in_qs``/``out_qs`` (the rest of the file indexes and passes these as
        sequences); a bare qtype was passed before;
      * ``set_identity`` is honoured, as in the other matchers of this file.

    Args:
        G: graph to rewrite; modified in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.

    Returns:
        True if at least one copy node was inserted.
    """
    split_nodes = [
        node for node in G.nodes() if isinstance(node, SplitParameters)
    ]
    has_modified_graph = False
    for node in split_nodes:
        # traverse reshapes or transposes that do nothing - check gen
        # find edges connected to concats
        res = self.find_split_concat(G, node)
        if res is None:
            continue
        # TODO(martin) - group edges that have adjacent inputs and outputs
        qrec = G.quantization[NodeId(node)] if G.quantization else None
        for idx, bundle in enumerate(res):
            if not bundle:
                continue
            has_modified_graph = True
            copy_node = CopyParameters("%s_copy_%s" % (node.name, idx))
            for edge_set in bundle:
                first_edge = edge_set[0]
                G.remove_edge(first_edge)
                G.add_edge(
                    NNEdge(copy_node,
                           first_edge.to_node,
                           to_idx=first_edge.to_idx))
            G.add_edge(NNEdge(node, copy_node, from_idx=idx))
            if G.quantization:
                # A copy has one input and one output, both quantized like
                # the split output that feeds it.
                G.quantization[NodeId(copy_node)] = qrec.__class__(
                    in_qs=[deepcopy(qrec.out_qs[idx])],
                    out_qs=[deepcopy(qrec.out_qs[idx])])

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Bypass and delete every node for which ``node_does_nothing`` is true.

    Args:
        G: graph to rewrite; modified in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted for interface compatibility; not read here.

    Returns:
        True if at least one node was removed.
    """
    # Evaluate the predicate on the whole graph before anything is changed.
    redundant = list(
        filter(lambda candidate: self.node_does_nothing(G, candidate),
               G.nodes()))

    has_modified_graph = False
    for node in redundant:
        has_modified_graph = True
        # Only the first in edge is considered.
        feed = G.in_edges(node.name)[0]
        G.remove_edge(feed)
        for consumer_edge in G.out_edges(node.name):
            G.remove_edge(consumer_edge)
            G.add_edge(
                NNEdge(feed.from_node,
                       consumer_edge.to_node,
                       from_idx=feed.from_idx,
                       to_idx=consumer_edge.to_idx))
        LOG.info(f'removing {node.name} that does nothing')
        G.remove(node)

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def match(self, G: GraphView, set_identity: bool = True):
    """Give every consumer of a shared constant its own constant node.

    The first out edge of a ``ConstantInputParameters`` node is left alone;
    each further out edge is re-routed to a fresh constant named
    ``<name>_<n>`` (n counting from 1) holding a copy of the value.

    Args:
        G: graph to rewrite; modified in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.

    Returns:
        True if at least one constant was duplicated.
    """
    has_modified = False
    for node in G.nodes(node_classes=ConstantInputParameters):
        out_edges = G.out_edges(node.name)
        if len(out_edges) <= 1:
            continue
        has_modified = True
        LOG.info(
            'node %s has more than one out edge and will be duplicated',
            node.name)
        for idx, extra_edge in enumerate(out_edges[1::], start=1):
            duplicate = ConstantInputParameters(
                f'{node.name}_{idx}',
                dims=Dim.unnamed(node.dims.shape),
                value=node.value.copy())
            G.remove_edge(extra_edge)
            G.add_edge(
                NNEdge(from_node=duplicate,
                       to_node=extra_edge.to_node,
                       to_idx=extra_edge.to_idx))

    if set_identity:
        self.set_identity(G)

    return has_modified
def match_function(self, G: GraphView):
    """Match a pad node feeding a filter-like node that has no padding.

    Returns:
        Whatever ``G.match_fragment`` returns for the two-node pattern.
    """
    def is_pad(node):
        return isinstance(node, PadParameters)

    def is_unpadded_filter(node):
        return (isinstance(node, FilterLikeParameters)
                and self.has_no_padding(node))

    sub = GraphView()
    sub.add_node(MatchNode('0', matcher=is_pad))
    sub.add_node(MatchNode('1', matcher=is_unpadded_filter))
    sub.add_edge(Edge('0', '1'))
    return G.match_fragment(sub)
def match_function(self, G: GraphView):
    """Match a valid linear node followed by a valid activation node.

    Returns:
        Whatever ``G.match_fragment`` returns for the two-node pattern.
    """
    def is_valid_linear(node):
        return isinstance(node, FcParameters) and self.valid_linear(node)

    def is_valid_activation(node):
        return (isinstance(node, ActivationParameters)
                and self.valid_activation(node))

    sub = GraphView()
    pattern = (('0', is_valid_linear), ('1', is_valid_activation))
    for label, predicate in pattern:
        sub.add_node(MatchNode(label, matcher=predicate))
    sub.add_edge(Edge('0', '1'))
    return G.match_fragment(sub)
def split_down_from(cur_g, node, res_g=None):
    """Move ``node`` and everything reachable below it out of ``cur_g``.

    Args:
        cur_g: graph the nodes are removed from; modified in place.
        node: node to start from.
        res_g: graph receiving the nodes and cloned edges; a new ``GraphView``
            is created when omitted.

    Returns:
        The graph holding ``node`` and what was found below it.
    """
    target = GraphView() if res_g is None else res_g

    # The out edges must be read before the node leaves cur_g.
    downstream = cur_g.out_edges(node.name)
    cur_g.remove(node)

    already_present = node in target.nodes()
    if not already_present:
        target.add_node(node)

    for edge in downstream:
        target.add_edge(edge.clone())
        split_down_from(cur_g, edge.to_node, res_g=target)
    return target
def match_function(self, G: GraphView):
    """Match a filter and a constant input both feeding a matrix add.

    The filter must arrive on input 0 of the add and the constant on input 1.

    Returns:
        Whatever ``G.match_fragment`` returns for the three-node pattern.
    """
    def node_of(node_class):
        # Build a predicate that accepts instances of node_class.
        return lambda node: isinstance(node, node_class)

    sub = GraphView()
    pattern_nodes = (
        ('0', FilterParameters),
        ('1', MatrixAddParameters),
        ('2', ConstantInputParameters),
    )
    for label, node_class in pattern_nodes:
        sub.add_node(MatchNode(label, matcher=node_of(node_class)))
    sub.add_edge(Edge('0', '1', to_idx=0))
    sub.add_edge(Edge('2', '1', to_idx=1))
    return G.match_fragment(sub)
def construct_subgraph(G, nodes):
    """Build a new graph from ``nodes`` and the edges of ``G`` between them.

    ``nodes`` is consumed: it is empty when the function returns. An edge is
    cloned when its other end is still waiting in ``nodes``, so every edge
    between two listed nodes is added once, while its first end is processed.

    Args:
        G: graph the edges are read from.
        nodes: list of nodes to include; emptied by this call.

    Returns:
        A new ``GraphView`` holding the nodes and cloned edges.
    """
    sub_g = GraphView()
    while nodes:
        current = nodes.pop(0)
        if current not in sub_g.nodes():
            sub_g.add_node(current)
        pending_out = [
            edge for edge in G.out_edges(current.name)
            if edge.to_node in nodes
        ]
        for edge in pending_out:
            sub_g.add_edge(edge.clone())
        pending_in = [
            edge for edge in G.in_edges(current.name)
            if edge.from_node in nodes
        ]
        for edge in pending_in:
            sub_g.add_edge(edge.clone())
    return sub_g
def match(self, G: GraphView, set_identity: bool = True):
    """Replace every node set from ``self.find_sets`` by an expression node.

    Args:
        G: graph to rewrite; modified in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.

    Returns:
        True if at least one expression fusion was created.
    """
    has_modified_graph = False
    for node_set in self.find_sets(G):
        has_modified_graph = True
        in_edges, out_edges, internal_edges = group_edges(G, node_set)

        # The fragment holds only the edges internal to the node set.
        frag = GraphView()
        for edge in internal_edges:
            frag.add_edge(edge)

        # Keys of in_edges/out_edges are (from_node, from_idx) pairs.
        in_keys = list(in_edges.keys())
        out_keys = list(out_edges.keys())

        in_mapping = [
            [(edge.to_node, edge.to_idx) for edge in edge_group]
            for edge_group in in_edges.values()
        ]
        in_dims = [source.out_dims[source_idx]
                   for source, source_idx in in_keys]
        out_dims = [source.out_dims[source_idx]
                    for source, source_idx in out_keys]
        constant_inputs = [
            key[0] for key in in_keys
            if isinstance(key[0], ConstantInputParameters)
        ]

        LOG.info('matched expression - creating expression %s',
                 self._expr_num)
        expr = ExpressionFusionParameters(f"expr_{self._expr_num}",
                                          subgraph=frag,
                                          input_mapping=in_mapping,
                                          output_mapping=out_keys,
                                          in_dims=in_dims,
                                          out_dims=out_dims,
                                          constant_inputs=constant_inputs)
        out_edge_mapping = [
            [(edge.to_node, edge.to_idx) for edge in edge_set]
            for edge_set in out_edges.values()
        ]
        G.replace_fragment(
            frag,
            expr,
            frag_in_edges=list(set.union(*in_edges.values())),
            frag_out_edges=list(set.union(*out_edges.values())),
            edge_in_mapping=in_keys,
            edge_out_mapping=out_edge_mapping)
        self._expr_num += 1

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Remove a concat that is immediately undone by a split.

    A pair is removed when the split's only input is a concat with a single
    out edge, neither node transposes, both use the same axis, and the split
    output sizes along that axis equal the concat input sizes one for one.
    The concat inputs are then wired directly to the split consumers.

    Changes from the previous implementation:
      * a split without exactly one in edge is skipped (zero in edges used to
        raise ``IndexError``);
      * split nodes are visited in graph order instead of ``set`` order, so
        the rewrite order and the log output are repeatable.

    Args:
        G: graph to rewrite; modified in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted for interface compatibility; not read here.

    Returns:
        True if at least one concat/split pair was removed.
    """
    has_modified_graph = False
    # dict.fromkeys de-duplicates like set() did but keeps the graph order.
    split_nodes = list(dict.fromkeys(
        node for node in G.nodes() if isinstance(node, SplitParameters)))
    for split_node in split_nodes:
        in_edges = G.in_edges(split_node.name)
        if len(in_edges) != 1:
            continue
        concat_node = in_edges[0].from_node
        if not isinstance(concat_node, ConcatParameters):
            continue
        if len(G.out_edges(concat_node.name)) > 1:
            continue
        if concat_node.transpose_out or split_node.transpose_in:
            continue
        if concat_node.axis != split_node.axis:
            continue
        axis = concat_node.axis
        split_out_sizes = [
            out_shape[axis] for out_shape in split_node.out_shapes
        ]
        if len(split_out_sizes) != len(concat_node.in_dims):
            continue
        if not all(split_out_sizes[idx] == in_dim.shape[axis]
                   for idx, in_dim in enumerate(concat_node.in_dims)):
            continue

        has_modified_graph = True
        LOG.info("removing unnecessary concat/split pair %s/%s",
                 concat_node.name, split_node.name)
        # Capture the surrounding edges before the nodes are removed.
        concat_in_edges = G.indexed_in_edges(concat_node.name)
        split_out_edges = G.indexed_out_edges(split_node.name)
        G.remove(split_node)
        G.remove(concat_node)
        for idx, concat_in_edge in enumerate(concat_in_edges):
            for out_edge in split_out_edges[idx]:
                G.add_edge(
                    NNEdge(from_node=concat_in_edge.from_node,
                           from_idx=concat_in_edge.from_idx,
                           to_node=out_edge.to_node,
                           to_idx=out_edge.to_idx))

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def match(self, G: GraphView, set_identity: bool = True):
    """Fold reverse nodes around an RNN node into the RNN's ``revert`` flag.

    A reverse whose single consumer is input 0 of an RNN node is removed and
    the RNN is marked ``revert = True``. Reverse nodes hanging off output 0 of
    that RNN are removed as well and bypassed.

    Args:
        G: graph to rewrite; modified in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.

    Returns:
        True if at least one reverse was fused.

    Raises:
        AssertionError: if the RNN node already has ``revert`` set.
    """
    def feeds_single_rnn(node):
        if not isinstance(node, ReverseParameters):
            return False
        consumers = G.out_edges(node.name)
        return (len(consumers) == 1
                and isinstance(G.out_edges(node.name)[0].to_node,
                               RNNBaseParameters))

    # Only works for reverses connected to one RNN node
    reverse_nodes = set([node for node in G.nodes() if feeds_single_rnn(node)])

    has_modified_graph = False
    for reverse_node in reverse_nodes:
        in_edges = G.in_edges(reverse_node.name)
        rnn_edge = G.out_edges(reverse_node.name)[0]
        if rnn_edge.to_idx != 0:
            LOG.warning("reverse on rnn input %s", rnn_edge.to_idx)
            continue

        rnn_node = rnn_edge.to_node
        assert not rnn_node.revert, "RNN node is already reversed!"
        rnn_node.revert = True
        LOG.info("fusing reverses into node %s", rnn_node.name)
        has_modified_graph = True

        # Bypass the reverse on the RNN input.
        G.remove(reverse_node)
        for edge in in_edges:
            G.add_edge(
                NNEdge(edge.from_node,
                       rnn_node,
                       from_idx=edge.from_idx,
                       to_idx=rnn_edge.to_idx))

        # Bypass any reverse on the RNN output.
        for edge in G.out_edges(rnn_node.name):
            if not isinstance(edge.to_node, ReverseParameters):
                continue
            if edge.from_idx != 0:
                LOG.warning("reverse on rnn output %s", edge.from_idx)
                continue
            rev_edges = G.out_edges(edge.to_node.name)
            G.remove(edge.to_node)
            for rev_edge in rev_edges:
                G.add_edge(
                    NNEdge(edge.from_node,
                           rev_edge.to_node,
                           from_idx=edge.from_idx,
                           to_idx=rev_edge.to_idx))

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def match(self, G: GraphView, set_identity: bool = True):
    """Fuse connected groups of fusable nodes into expression nodes.

    Nodes that are instances of ``FUSE_NODES``, plus constant inputs whose
    first output has size 1, are grouped with ``group_nodes``. Groups made
    only of constant inputs are dropped; every other group is replaced by an
    ``ExpressionFusionParameters`` node.

    Change from the previous implementation: ``self._expr_num`` is advanced
    after each fusion. It was never incremented before, so every expression
    created by this matcher received the same ``expr_<n>`` name.

    NOTE(review): ``edge_in_mapping`` is sorted by ``from_idx`` while
    ``input_mapping`` and ``constant_inputs`` keep the dictionary order -
    confirm that these are meant to line up.

    Args:
        G: graph to rewrite; modified in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.

    Returns:
        True if at least one expression fusion was created.
    """
    has_modified_graph = False
    # collect connected node sets
    node_sets = group_nodes(G, [
        node for node in G.nodes()
        if isinstance(node, FUSE_NODES) or (
            isinstance(node, ConstantInputParameters)
            and node.out_dims[0].size() == 1)
    ])
    # remove sets that are only ConstantInputs
    node_sets = [
        node_set for node_set in node_sets
        if not all(isinstance(node, ConstantInputParameters)
                   for node in node_set)
    ]
    for node_set in node_sets:
        has_modified_graph = True
        in_edges, out_edges, internal_edges = group_edges(G, node_set)
        frag = GraphView()
        for edge in internal_edges:
            frag.add_edge(edge)
        in_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_group]
                      for edge_group in in_edges.values()]
        out_mapping = list(out_edges.keys())
        # One flag per fragment input: is it fed by a constant?
        constant_inputs = [
            isinstance(node_edge_idx[0], ConstantInputParameters)
            for node_edge_idx in in_edges
        ]
        expr = ExpressionFusionParameters("expr_%s" % self._expr_num,
                                          subgraph=frag,
                                          input_mapping=in_mapping,
                                          output_mapping=out_mapping,
                                          constant_inputs=constant_inputs)
        G.replace_fragment(
            frag,
            expr,
            frag_in_edges=list(set.union(*in_edges.values())),
            frag_out_edges=list(set.union(*out_edges.values())),
            edge_in_mapping=sorted(list(in_edges.keys()),
                                   key=lambda x: x[1]),
            edge_out_mapping=[[(edge.to_node, edge.to_idx)
                               for edge in edge_set]
                              for edge_set in out_edges.values()])
        # Keep the generated node names unique.
        self._expr_num += 1

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse ``relu6(x) * (1/6)`` into a single hard-sigmoid activation.

    The pattern is a ``ReluActivationParameters`` node with ``upper_bound``
    6 whose only consumer is a two-input ``MatrixMulParameters`` node, the
    other input being a ``ConstantInputParameters`` node used only there and
    accepted by ``check_equals(G, constant, 1.0/6.0)``. The three nodes are
    replaced by one ``HSigmoidActivationParameters`` node with ``offset=0``.

    Changes from the previous implementation:
      * ``set_identity`` is honoured, as in the other matchers of this file
        (it was accepted but ignored);
      * the quantization records of the removed relu and mul nodes are
        deleted together with the constant's, instead of being left behind.

    Args:
        G: graph to rewrite; modified in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted for interface compatibility; not read here.

    Returns:
        True if at least one fusion was performed.
    """
    something_changed = False
    relu6_nodes = [
        node for node in G.nodes(node_classes=ReluActivationParameters)
        if node.upper_bound == 6
    ]
    for relu_node in relu6_nodes:
        out_edges = G.out_edges(relu_node)
        if len(out_edges) != 1 or not isinstance(out_edges[0].to_node,
                                                 MatrixMulParameters):
            continue
        mul_node = out_edges[0].to_node
        in_edges = G.in_edges(mul_node)
        if len(in_edges) != 2:
            continue
        # The mul input that does not come from the relu.
        other_edge = (set(in_edges) - {out_edges[0]}).pop()
        constant_node = other_edge.from_node
        if len(G.out_edges(constant_node)) != 1:
            continue
        if (not isinstance(constant_node, ConstantInputParameters)
                or not check_equals(G, constant_node, 1.0/6.0)):
            continue

        something_changed = True
        activation = HSigmoidActivationParameters(
            G.unique_name(f'{mul_node.name}_hsigmoid'), offset=0)

        in_edges = G.in_edges(relu_node)
        out_edges = G.out_edges(mul_node)

        nodes_to_replace = [relu_node, mul_node, constant_node]

        LOG.info(f'fusing {", ".join(node.name for node in nodes_to_replace)} into HSIGMOID {activation.name}')
        G.remove_all(nodes_to_replace)

        for in_edge in in_edges:
            G.add_edge(NNEdge.clone(in_edge, to_node=activation, to_idx=0))
        for out_edge in out_edges:
            G.add_edge(NNEdge.clone(
                out_edge, from_node=activation, from_idx=0))

        if G.quantization:
            reluqrec = G.quantization[NodeId(relu_node)]
            mulqrec = G.quantization[NodeId(mul_node)]
            # The fused node consumes what the relu consumed and produces
            # what the mul produced.
            pqrec = QRec.copy_ktype(
                reluqrec, in_qs=reluqrec.in_qs, out_qs=mulqrec.out_qs)
            for removed_node in nodes_to_replace:
                del G.quantization[NodeId(removed_node)]
            G.quantization[NodeId(activation)] = pqrec

    if set_identity:
        self.set_identity(G)

    return something_changed
def move_constant(cls, G: GraphView, params, in_qs):
    """Make sure a constant operand of ``params`` sits on its second input.

    If input 1 is already a constant nothing changes. If input 0 is a
    constant the two inputs are swapped and transposes are inserted on both
    inputs and on the output, using the identity A.B = (BT.AT)T. When a third
    input is present the swap only happens if its size equals
    ``params.in_dims[0].shape[1]``.

    Args:
        G: graph containing ``params``; modified in place when a swap is done.
        params: the node whose inputs are examined.
        in_qs: input quantization entries of ``params``.

    Returns:
        ``(constant_node, in_qs)``. ``constant_node`` is ``None`` when neither
        of the first two inputs is a constant or the swap is not possible.
        After a swap the first two entries of the returned ``in_qs`` are
        exchanged.
    """
    in_edges = G.indexed_in_edges(params.name)
    in1_node = in_edges[0].from_node
    in2_node = in_edges[1].from_node

    # Constant already on the second input: nothing to do.
    if isinstance(in2_node, ConstantInputParameters):
        return in2_node, in_qs
    # No constant on either of the first two inputs.
    if not isinstance(in1_node, ConstantInputParameters):
        return None, in_qs

    if len(params.in_dims) > 2:
        # check if the bias has the correct length to move constant
        # it must have a length equal to the second tensors second dimension after transpose
        bias_size = params.in_dims[2].size()
        in1_shape = params.in_dims[0].shape
        if in1_shape[1] != bias_size:
            return None, in_qs

    operand_edges = in_edges[:2:]
    for edge in operand_edges:
        G.remove_edge(edge)
    # swap edges to move constant onto input 2
    to_idx = 1
    for edge in operand_edges:
        G.add_edge(
            NNEdge(from_node=edge.from_node,
                   to_node=edge.to_node,
                   from_idx=edge.from_idx,
                   to_idx=to_idx))
        to_idx = 1 - to_idx

    # use A.B = (BT.AT)T identity
    tin1, tin2, tout = [
        TransposeParameters(G.unique_name(f'{params.name}_{suffix}'),
                            transpose=(1, 0))
        for suffix in ('tin1', 'tin2', 'tout')
    ]
    G.insert_node_before(tin1, params)
    G.insert_node_before(tin2, params, to_idx=1)
    G.insert_node_after(params, tout)
    LOG.warning('transposes inserted on %s - rerun adjust', params.name)
    return in1_node, [in_qs[1], in_qs[0]] + in_qs[2::]
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse each fragment found by the MatScale pair matcher into one node.

    Args:
        G: graph to rewrite; modified in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted for interface compatibility; not read here.

    Returns:
        True if at least one fragment was fused.
    """
    matcher = MatScalePairMatchFactory().get_matcher()
    has_modified_graph = False
    for frag, match in matcher.match_graph(G):
        # Capture the surrounding edges before the fragment is removed.
        match_edges = [
            G.indexed_in_edges(in_node.name)[in_idx]
            for in_node, in_idx in match
        ]
        first_node = frag.inputs()[0]
        last_node = frag.outputs()[0]
        out_edges = G.out_edges(last_node.name)
        for frag_node in frag.nodes():
            G.remove(frag_node)

        input_mapping = MatScaleFusionParameters.get_mapping_from_edges(
            match_edges)
        fusion_name = "{}_{}_fusion".format(first_node.name, last_node.name)
        fnode = MatScaleFusionParameters(
            fusion_name,
            fusion_type="vec_scalar",
            subgraph=frag,
            input_mapping=MatScaleFusionParameters.convert_input_mapping(
                input_mapping))
        has_modified_graph = True
        G.add_node(fnode)

        fnode.in_dims_hint = [None] * 3
        for idx, edge in enumerate(match_edges):
            target_idx = list(input_mapping[edge.to_node].keys())[0]
            new_edge = edge.clone(to_node=fnode, to_idx=target_idx)
            # Carry over the producer's hint for this input when it has one.
            producer_hints = new_edge.from_node.out_dims_hint
            if producer_hints:
                fnode.in_dims_hint[idx] = new_edge.from_node.out_dims_hint[
                    edge.from_idx]
            G.add_edge(new_edge)
        for edge in out_edges:
            G.add_edge(edge.clone(from_node=fnode))

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def find_concats_up(G, concat, subgraph: GraphView = None):
    """Collect into ``subgraph`` what feeds ``concat``.

    Original description: produces a subgraph of concats operating on axis 0
    separated by copys or reshapes. The output node will be the final concat.
    The input nodes will be all the inputs to a condensed concat that can
    replace this subgraph.

    For every input edge of ``concat``, ``traverse_to_concat`` is asked for a
    path. A non-empty path is added edge by edge; otherwise the input is
    represented by a ``DummyInput`` node wired to the same input index.

    Args:
        G: graph being analysed.
        concat: concat node whose inputs are followed.
        subgraph: graph to add to; a new ``GraphView`` is created when omitted.

    Returns:
        The populated subgraph.
    """
    if subgraph is None:
        subgraph = GraphView()
    for edge in G.indexed_in_edges(concat.name):
        edge_path = traverse_to_concat(G, edge, subgraph)
        if not edge_path:
            placeholder = DummyInput(
                f"{edge.from_node.name}_{edge.from_idx}", edge)
            subgraph.add_edge(
                NNEdge(from_node=placeholder,
                       to_node=edge.to_node,
                       to_idx=edge.to_idx))
            continue
        for inter_edge in edge_path:
            subgraph.add_edge(inter_edge)
    return subgraph
def match(self, G: GraphView, set_identity: bool = True):
    """Remove relu nodes reported as redundant by ``find_redundant_relus``.

    Every out edge of a graph input is given a ``status`` tuple:
    ``(True, max_value)`` when the input is a constant with a value whose
    minimum is non-negative, ``(False, False)`` otherwise.
    ``find_redundant_relus`` is then run from each of those edges and the
    nodes it returns are removed and bypassed.

    Changes from the previous implementation:
      * ``status`` is reset to ``(False, False)`` for every input. The old
        ``if``/``else`` chain had a branch that assigned nothing, so
        ``status`` was either unbound (``NameError``) or left over from the
        previous input;
      * a redundant node is removed *before* its consumers are reconnected,
        so a consumer input is never connected twice.

    Args:
        G: graph to rewrite; modified in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.

    Returns:
        True if at least one relu was removed.
    """
    visited_edges = {}
    nodes_to_remove = []
    has_modified_graph = False
    for node in G.inputs():
        status = (False, False)
        # check if constantinput. if is then check if positive and check max value
        if isinstance(node, ConstantInputParameters) and node.value is not None:
            if G.has_quantized_parameters:
                qrec = G.quantization[NodeId(node)]
                qtype = qrec.out_qs[0]
                if hasattr(qtype, 'wrapped'):
                    qtype = qtype.wrapped
                val = qtype.dequantize(node.value)
            else:
                val = node.value
            if val.min() >= 0:
                status = (True, val.max())

        for edge in G.out_edges(node.name):
            visited_edges[edge] = status
            nodes_to_remove += find_redundant_relus(
                G, edge.to_node, visited_edges)

    for node in nodes_to_remove:
        has_modified_graph = True
        # Only relus so only one in edge
        in_edge = G.in_edges(node.name)[0]
        out_edges = G.out_edges(node.name)
        G.remove(node)
        for edge in out_edges:
            G.add_edge(
                NNEdge(from_node=in_edge.from_node,
                       from_idx=in_edge.from_idx,
                       to_node=edge.to_node,
                       to_idx=edge.to_idx))

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Remove reshapes that do nothing and collapse unnecessary ones.

    Two rewrites are repeated until neither applies:
      1. a reshape without transpose whose new shape equals its old shape is
         removed and its neighbours reconnected;
      2. the first reshape for which ``self.validate_reshape`` returns a
         result has the reported candidate nodes removed, their consumers
         connected to the reshape, and its shape set to the reported shape.

    Change from the previous implementation: the loop-control flag was also
    the return value, and it is always ``False`` when the loop ends, so the
    method reported "nothing changed" even after rewriting the graph. The
    result is now tracked separately.

    Args:
        G: graph to rewrite; modified in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted for interface compatibility; not read here.

    Returns:
        True if the graph was modified in any pass.
    """
    has_modified_graph = False
    changed_this_pass = True
    while changed_this_pass:
        changed_this_pass = False

        # Snapshot: nodes are removed while walking the list.
        for reshape in list(G.nodes(node_classes=(ReshapeParameters, ))):
            if (not reshape.has_transpose
                    and reshape.shape.shape == reshape.old_shape.shape):
                changed_this_pass = True
                LOG.info('removing reshape that does nothing %s',
                         reshape.name)
                G.remove_and_reconnect(reshape, edge_class=NNEdge)
                nid = NodeId(reshape)
                if G.quantization and nid in G.quantization:
                    del G.quantization[nid]

        for reshape in list(G.nodes(node_classes=(ReshapeParameters, ))):
            res = self.validate_reshape(G, reshape)
            if not res:
                continue
            LOG.info('unnecessary reshape found after %s', reshape.name)
            changed_this_pass = True
            (reshape, candidates, out_shape) = res
            for candidate in candidates:
                LOG.info('removing unnecessary reshape or transpose %s',
                         candidate.name)
                edges = G.out_edges(candidate.name)
                G.remove(candidate)
                nid = NodeId(candidate)
                if G.quantization and nid in G.quantization:
                    del G.quantization[nid]
                for edge in edges:
                    G.add_edge(
                        NNEdge(from_node=reshape,
                               to_node=edge.to_node,
                               to_idx=edge.to_idx))
            reshape.shape = Dim.unnamed(out_shape)
            # The graph changed: restart both scans from a fresh node list.
            break

        has_modified_graph = has_modified_graph or changed_this_pass

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def match_function(self, G: GraphView):
    """Match a relu with upper bound 6 multiplied by a constant 1/6.

    The relu must arrive on input 0 of the ``MatrixMulParameters`` node and
    the constant, accepted by ``check_equals(G, node, 1.0 / 6.0)``, on input 1.

    Returns:
        Whatever ``G.match_fragment`` returns for the three-node pattern.
    """
    def is_relu6(node):
        return (isinstance(node, ReluActivationParameters)
                and node.upper_bound == 6)

    def is_mul(node):
        return isinstance(node, MatrixMulParameters)

    def is_one_sixth(node):
        return (isinstance(node, ConstantInputParameters)
                and check_equals(G, node, 1.0 / 6.0))

    sub = GraphView()
    for label, predicate in (('0', is_relu6), ('1', is_mul),
                             ('2', is_one_sixth)):
        sub.add_node(MatchNode(label, matcher=predicate))
    sub.add_edge(Edge('0', '1', to_idx=0))
    sub.add_edge(Edge('2', '1', to_idx=1))
    return G.match_fragment(sub)
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse chains starting at a ``VALID_FUSIONS`` node into fusion nodes.

    ``self.get_node_list`` describes the chain found from each start node.
    Chains of at least two nodes are replaced by one instance of
    ``node_list.fusions_class`` that takes over the inputs of the first node
    and the outputs of the last.

    Args:
        G: graph to rewrite; modified in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted for interface compatibility; not read here.

    Returns:
        True if at least one chain was fused.
    """
    has_modified_graph = False
    fusable_classes = tuple(VALID_FUSIONS.keys())
    for node in G.nodes(node_classes=fusable_classes):
        node_list = self.get_node_list(G, node,
                                       FusionMatch(self._default_ktype))
        if node_list is None or len(node_list.order) < 2:
            continue
        LOG.info("fusing nodes %s", ",".join(
            (member.name for member in node_list.order)))
        has_modified_graph = True

        # Chain the matched nodes one after the other inside the fusion.
        subgraph = GraphView()
        last_node = None
        for snode in node_list.order:
            if last_node is not None:
                subgraph.add_edge(
                    NNEdge(from_node=last_node, to_node=snode))
            last_node = snode

        # assumption here is that the first node could have multiple inputs but definitely has only
        # one output
        head = node_list.node
        input_mapping = [[(head, idx)]
                         for idx in range(G.num_in_edges(head.name))]
        output_mapping = [(last_node, 0)]
        pnode = node_list.fusions_class(
            head.name + '_fusion',
            fusion_type=node_list.fusion_type,
            subgraph=subgraph,
            input_mapping=input_mapping,
            output_mapping=output_mapping)

        if G.quantization:
            # TODO - stats
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                # Inputs quantized like the first node, outputs like the last.
                prec = QRec.copy_ktype(qrecs[0],
                                       in_qs=qrecs[0].in_qs,
                                       out_qs=qrecs[-1].out_qs)
                for fnode in pnode.contained_nodes():
                    G.quantization.move_to_fusion(fnode, pnode)
                G.quantization[NodeId(pnode)] = prec

        in_edges = G.in_edges(head.name)
        out_edges = G.out_edges(last_node.name)
        for snode in node_list.order:
            G.remove(snode)
        for edge in in_edges:
            G.add_edge(
                NNEdge(edge.from_node,
                       pnode,
                       from_idx=edge.from_idx,
                       to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(
                NNEdge(pnode,
                       edge.to_node,
                       from_idx=edge.from_idx,
                       to_idx=edge.to_idx))

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Hoist an operation duplicated on split outputs above the split.

    When ``self.moveable_same_operation_edges`` reports edges for a split
    node, every node directly below the split is removed and its consumers
    are attached to the split output it was on. The node reached by the
    first reported edge (sorted by node name) is then re-inserted on the
    split's single input edge. With quantization present the graph is
    re-quantized.

    Args:
        G: graph to rewrite; modified in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted for interface compatibility; not read here.

    Returns:
        True if at least one split was rewritten.

    Raises:
        AssertionError: if the split does not have exactly one in edge.
    """
    has_modified_graph = False
    for node in G.nodes(node_classes=SplitParameters):
        same_op_edges = self.moveable_same_operation_edges(G, node)
        if not same_op_edges:
            continue
        has_modified_graph = True
        in_edges = G.in_edges(node.name)
        assert len(in_edges) == 1
        # sort by name to ensure that operation is repeatable
        same_op_edges.sort(key=lambda op_edge: op_edge.to_node.name)
        keep_node = same_op_edges[0].to_node
        LOG.info('split node %s has duplicate operations on its out edges',
                 node.name)
        LOG.info('moving %s before split node %s', keep_node.name, node.name)

        for edge in G.out_edges(node.name):
            below = edge.to_node
            below_out_edges = G.out_edges(below.name)
            G.remove(below)
            if below != keep_node:
                LOG.info('deleting duplicate node %s', below.name)
                if G.quantization:
                    nid = NodeId(below)
                    if nid in G.quantization:
                        del G.quantization[nid]
            # Consumers of the removed node now read the split output directly.
            for out_edge in below_out_edges:
                G.add_edge(
                    NNEdge(from_node=node,
                           from_idx=edge.from_idx,
                           to_node=out_edge.to_node,
                           to_idx=out_edge.to_idx))

        G.insert_node_at_edge(keep_node, in_edges[0], edge_class=NNEdge)
        if G.quantization:
            quantizer = NewQuantizer.from_quantized_graph(G)
            quantizer.quantize()

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def match(self, G: GraphView, set_identity: bool = True):
    """Fuse a convolution with its following activation/pool nodes.

    ``self.get_node_list`` describes the chain found after each
    ``Conv2DParameters`` node. For ``conv_active_pool`` an average pool is
    dropped from the chain; for ``conv_pool_active`` an average pool followed
    by a non-relu activation is not fused at all. The remaining chain is
    replaced by a ``ConvFusionParameters`` node.

    Args:
        G: graph to rewrite; modified in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.

    Returns:
        True if at least one chain was fused.
    """
    # (record classes accepted for the first node, class used for the fusion)
    qrec_families = (
        ((SymmetricQuantizationRecord,
          SymmetricScalableFilterQuantizationRecord),
         SymmetricQuantizationRecord),
        ((MultQuantizationRecord, MultScalableFilterQuantizationRecord),
         MultQuantizationRecord),
        ((Float32QuantizationRecord,
          Float32ScalableFilterQuantizationRecord),
         Float32QuantizationRecord),
    )

    has_modified_graph = False
    conv_nodes = [
        params for params in G.nodes()
        if isinstance(params, Conv2DParameters)
    ]
    for conv_node in conv_nodes:
        node_list = self.get_node_list(G, conv_node)
        if node_list is None or len(node_list.order) < 2:
            continue
        if node_list.fusion_type == 'conv_active_pool':
            if node_list.pool.pool_type == "average":
                node_list.order = node_list.order[:2:]
                node_list.pool = None
        elif node_list.fusion_type == 'conv_pool_active':
            if (node_list.pool.pool_type == "average"
                    and node_list.active.activation != "relu"):
                continue

        LOG.info("fusing nodes %s", ",".join(
            (member.name for member in node_list.order)))
        has_modified_graph = True

        # Chain the matched nodes one after the other inside the fusion.
        subgraph = GraphView()
        last_node = None
        for member in node_list.order:
            if last_node is not None:
                subgraph.add_edge(NNEdge(from_node=last_node, to_node=member))
            last_node = member

        input_mapping = [[(node_list.conv, idx)] for idx in range(3)]
        output_mapping = [(last_node, 0)]
        pnode = ConvFusionParameters(
            node_list.conv.name + '_fusion',
            fusion_type=node_list.fusion_type,
            subgraph=subgraph,
            in_dims_hint=node_list.conv.in_dims_hint,
            out_dims_hint=node_list.conv.out_dims_hint,
            input_mapping=input_mapping,
            output_mapping=output_mapping)

        if G.quantization:
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                # prec stays None when the first record matches no family.
                prec = None
                for accepted, record_class in qrec_families:
                    if isinstance(qrecs[0], accepted):
                        prec = record_class(in_qs=qrecs[0].in_qs,
                                            out_qs=qrecs[-1].out_qs)
                        break
                for member in pnode.contained_nodes():
                    G.quantization.move_to_fusion(member, pnode)
                G.quantization[NodeId(pnode)] = prec

        in_edges = G.in_edges(node_list.conv.name)
        out_edges = G.out_edges(last_node.name)
        for member in node_list.order:
            G.remove(member)
        for edge in in_edges:
            G.add_edge(NNEdge(edge.from_node, pnode,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(NNEdge(pnode, edge.to_node,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Remove relu nodes reported as redundant by ``find_redundant_relus``.

    Every out edge of a graph input is given a ``status`` tuple:
    ``(True, max_value)`` when the input is a constant with a value whose
    dequantized minimum is non-negative, ``(False, False)`` otherwise.
    ``find_redundant_relus`` is then run from each of those edges and the
    nodes it returns are removed and bypassed.

    Args:
        G: graph to rewrite; modified in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted for interface compatibility; not read here.

    Returns:
        True if at least one relu was removed.
    """
    visited_edges = {}
    nodes_to_remove = []
    for node in G.inputs():
        # check if constantinput. if is then check if positive and check max value
        status = (False, False)
        if isinstance(node, ConstantInputParameters) and node.value is not None:
            val = node.dqvalue
            if np.min(val) >= 0:
                status = (True, np.max(val))
        for edge in G.out_edges(node.name):
            visited_edges[edge] = status
            nodes_to_remove += find_redundant_relus(
                G, edge.to_node, visited_edges)

    has_modified_graph = False
    for relu in nodes_to_remove:
        has_modified_graph = True
        # Only relus so only one in edge
        LOG.info("removing redundant relu %s", relu.name)
        feed = G.in_edges(relu.name)[0]
        consumers = G.out_edges(relu.name)
        G.remove(relu)
        for consumer in consumers:
            G.add_edge(NNEdge(from_node=feed.from_node,
                              from_idx=feed.from_idx,
                              to_node=consumer.to_node,
                              to_idx=consumer.to_idx))

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Eliminate the unpack node that selects the last cell of an RNN output.

    Unpack candidates are found with ``self.find_unpack`` for RNN nodes with
    more than one output cell and filtered by ``self.validate_slices`` and
    ``self.validate_multi_branch``. For every remaining unpack node the paths
    are processed with ``self.process_path`` and the unpack node is removed.
    If it is a strided slice that changes the shape, a reshape takes its
    place; otherwise its producer is wired straight to its consumers.

    Changes from the previous implementation:
      * the reshape branch looked up ``NodeId(unpack)``, an undefined name
        that raised ``NameError`` on quantized graphs; it now uses
        ``unpack_node``;
      * the quantization record of the removed unpack node is dropped in both
        branches (in the reshape branch it is handed to the reshape first).

    Args:
        G: graph to rewrite; modified in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted for interface compatibility; not read here.

    Returns:
        False when nothing qualifies, True otherwise.

    Raises:
        AssertionError: if an unpack node does not have exactly one in edge.
    """
    rnn_nodes = [
        self.find_unpack(G, node) for node in G.nodes()
        if isinstance(node, RNNBaseParameters) and node.n_output_cells > 1
    ]
    rnn_nodes_by_slice = self.validate_slices(G, rnn_nodes)
    rnn_nodes_by_slice = self.validate_multi_branch(G, rnn_nodes_by_slice)
    if not rnn_nodes_by_slice:
        return False

    for unpack_node, rnn_unpacks in rnn_nodes_by_slice.items():
        modified_nodes = set()
        for rnn_unpack in rnn_unpacks:
            self.process_path(G, rnn_unpack, modified_nodes)
        # since process path will have removed all unnecessary nodes the edges will be correct here
        out_edges = G.out_edges(unpack_node.name)
        in_edges = G.in_edges(unpack_node.name)
        assert len(in_edges
                   ) == 1, "expecting unpack node to have only one in edge"
        in_edge = in_edges[0]
        changes_shape = unpack_node.changes_shape if isinstance(
            unpack_node, StridedSliceParameters) else False

        LOG.info("Eliminating last cell unpack: %s", unpack_node.name)
        G.remove(unpack_node)
        unpack_nid = NodeId(unpack_node)

        # Here the strided slice can change the output shape of the RNN
        # so insert a reshape to do the shape change
        if changes_shape:
            reshape = ReshapeParameters(
                unpack_node.name + '_reshape',
                old_shape=Dim.unnamed(unpack_node.post_slice_shape),
                shape=Dim.unnamed(unpack_node.out_shape))
            G.add_edge(
                NNEdge(from_node=in_edge.from_node,
                       to_node=reshape,
                       from_idx=in_edge.from_idx))
            for out_edge in out_edges:
                G.add_edge(
                    NNEdge(from_node=reshape,
                           to_node=out_edge.to_node,
                           to_idx=out_edge.to_idx))
            if G.quantization:
                # The reshape inherits the record of the node it replaces.
                G.quantization[NodeId(reshape)] = G.quantization[unpack_nid]
                del G.quantization[unpack_nid]
        else:
            for out_edge in out_edges:
                G.add_edge(
                    NNEdge(from_node=in_edge.from_node,
                           to_node=out_edge.to_node,
                           from_idx=in_edge.from_idx,
                           to_idx=out_edge.to_idx))
            if G.quantization:
                del G.quantization[unpack_nid]

    if set_identity:
        self.set_identity(G)

    return True
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Merge groups of global pooling nodes into a single reduction.

    ``self.reductions`` returns the group a pooling node belongs to. For
    groups of more than one node, the first node receives the combined axes
    and keep-dims setting (computed by folding ``reduce_reduction`` over the
    group), the other nodes are removed, and a reshape is appended when
    keep-dims is set and the resulting rank differs from the input rank.

    Changes from the previous implementation:
      * nodes removed as part of a group are dropped from the worklist, so
        they are not passed to ``self.reductions`` after leaving the graph;
      * the inner loop no longer rebinds the outer ``node`` variable.

    NOTE(review): removal uses ``G.remove(node.name)`` (a name) while the
    other matchers in this file pass the node object - confirm that
    ``GraphView.remove`` accepts both.

    Args:
        G: graph to rewrite; modified in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: accepted for interface compatibility; not read here.

    Returns:
        True if at least one group was merged.
    """
    nodes = list(G.nodes(node_classes=GlobalPoolingParameters))
    modified_graph = False
    while nodes:
        node = nodes.pop()
        node_group = self.reductions(G, node)
        if len(node_group) <= 1:
            continue
        modified_graph = True
        reduction_axes, new_shape, has_keepdims, _ = reduce(
            reduce_reduction, node_group, None)
        new_node = node_group[0]
        new_node.axis = sorted(list(reduction_axes))
        new_node.keep_dims = has_keepdims
        out_edges = G.out_edges(node_group[-1].name)
        if G.quantization:
            # The merged node produces what the last node of the group did.
            last_qrec = G.quantization[NodeId(node_group[-1])]
            G.quantization[NodeId(new_node)].out_qs = last_qrec.out_qs

        merged_nodes = node_group[1::]
        for merged_node in merged_nodes:
            G.remove(merged_node.name)
            nid = NodeId(merged_node)
            if G.quantization and nid in G.quantization:
                del G.quantization[nid]
        # Removed nodes are no longer in the graph: do not revisit them.
        nodes = [pending for pending in nodes
                 if pending not in merged_nodes]

        if has_keepdims and len(new_shape) != len(
                new_node.in_dims[0].shape):
            rparams = ReshapeParameters(
                G.unique_name(f'{new_node.name}_reshape'),
                shape=Dim.unnamed(new_shape))
            if G.quantization:
                G.quantization.copy_qrec(last_qrec, 'out', 0, rparams)
            G.add_edge(NNEdge(new_node, rparams))
            new_node = rparams
        for edge in out_edges:
            G.add_edge(NNEdge(new_node, edge.to_node, to_idx=edge.to_idx))

    if set_identity:
        self.set_identity(G)

    return modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse a linear (FC) node with the activation that follows it.

    The set of accepted activations depends on the ``group_identity`` keyword:
    ``VALID_ACTIVATIONS_POW2`` for ``'pow2_match_group'``,
    ``VALID_ACTIVATIONS_SQ8`` otherwise. Chains of at least two nodes returned
    by ``self.get_node_list`` are replaced by a ``LinearFusionParameters`` node.

    Args:
        G: graph to rewrite; modified in place.
        set_identity: when true, ``self.set_identity(G)`` is called at the end.
        **kwargs: ``group_identity`` is read; anything else is ignored.

    Returns:
        True if at least one chain was fused.
    """
    if kwargs.get('group_identity') == 'pow2_match_group':
        valid_activations = VALID_ACTIVATIONS_POW2
    else:
        valid_activations = VALID_ACTIVATIONS_SQ8

    has_modified_graph = False
    fc_nodes = [
        params for params in G.nodes() if isinstance(params, FcParameters)
    ]
    for fc_node in fc_nodes:
        node_list = self.get_node_list(G, fc_node, valid_activations)
        if node_list is None or len(node_list.order) < 2:
            continue
        LOG.info("fusing nodes %s", ",".join(
            (member.name for member in node_list.order)))
        has_modified_graph = True

        # Chain the matched nodes one after the other inside the fusion.
        subgraph = GraphView()
        last_node = None
        for member in node_list.order:
            if last_node is not None:
                subgraph.add_edge(
                    NNEdge(from_node=last_node, to_node=member))
            last_node = member

        linear = node_list.linear
        input_mapping = [[(linear, idx)] for idx in range(3)]
        output_mapping = [(last_node, 0)]
        pnode = LinearFusionParameters(
            linear.name + '_fusion',
            fusion_type=node_list.fusion_type,
            subgraph=subgraph,
            input_mapping=input_mapping,
            output_mapping=output_mapping)

        if G.quantization:
            # TODO - stats
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                # Inputs quantized like the first node, outputs like the last.
                prec = QRec.copy_ktype(
                    qrecs[0], in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                for member in pnode.contained_nodes():
                    G.quantization.move_to_fusion(member, pnode)
                G.quantization[NodeId(pnode)] = prec

        in_edges = G.in_edges(linear.name)
        out_edges = G.out_edges(last_node.name)
        for member in node_list.order:
            G.remove(member)
        for edge in in_edges:
            G.add_edge(NNEdge(edge.from_node, pnode,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in out_edges:
            G.add_edge(NNEdge(pnode, edge.to_node,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))

    if set_identity:
        self.set_identity(G)

    return has_modified_graph