def replace_function(self, G: NNGraph, subgraph: GraphView):
    """Replace a two-node matched subgraph with an ActivationFusion node.

    The first node yielded by ``subgraph.nodes()`` gets ``step_idx`` 0 and
    supplies the fusion's name and ``op_name`` prefix; the second node gets
    ``step_idx`` 1.  When ``G`` is quantized, the records of the fused nodes
    are moved into the fusion and the fusion receives a record combining the
    first record's ``in_qs`` with the last record's ``out_qs``.

    Returns:
        The new ActivationFusion node.

    Raises:
        NotImplementedError: if the first quantization record is not of the
            symmetric, multiplicative or float32 family.  This is raised before
            any record is moved, so the quantization set is left unchanged.
    """
    nodes = list(subgraph.nodes())
    pnode = ActivationFusion(nodes[0].name + "fusion",
                             nodes[0].op_name + "_active", subgraph)
    nodes[0].step_idx = 0
    nodes[1].step_idx = 1
    LOG.debug("fused nodes %s", ",".join(node.name for node in nodes))
    if G.quantization:
        qrecs = G.quantization.get_all(subgraph.nodes())
        if qrecs:
            first_qrec = qrecs[0]
            if isinstance(first_qrec, (SymmetricQuantizationRecord,
                                       SymmetricScalableFilterQuantizationRecord)):
                prec_class = SymmetricQuantizationRecord
            elif isinstance(first_qrec, (MultQuantizationRecord,
                                         MultScalableFilterQuantizationRecord)):
                prec_class = MultQuantizationRecord
            elif isinstance(first_qrec, (Float32QuantizationRecord,
                                         Float32ScalableFilterQuantizationRecord)):
                prec_class = Float32QuantizationRecord
            else:
                # Fail before touching the quantization set instead of hitting
                # an unbound ``prec`` after the records were already moved.
                raise NotImplementedError(
                    "unsupported quantization record type %s for fusion"
                    % type(first_qrec).__name__)
            prec = prec_class(in_qs=first_qrec.in_qs, out_qs=qrecs[-1].out_qs)
            for node in subgraph.nodes():
                G.quantization.move_to_fusion(node, pnode)
            G.quantization[NodeId(pnode)] = prec
    return pnode
def replace_function(self, G: GraphView, subgraph: GraphView):
    """Replace a matched convolution subgraph with a ConvFusionParameters node.

    Nodes yielded by ``subgraph.nodes()`` are numbered through ``step_idx``
    up to and including the first Conv2DParameters node.  That node's name,
    suffixed with ``_fusion``, names the fusion.  When ``G`` is quantized, the
    contained nodes' records are moved into the fusion and the fusion receives
    a record combining the first record's ``in_qs`` with the last record's
    ``out_qs``.

    Returns:
        ``(pnode, None, None)`` where ``pnode`` is the new fusion node.

    Raises:
        DontReplaceError: if ``self.validate_match`` rejects the subgraph.
        NotImplementedError: if the first quantization record is not of the
            symmetric, multiplicative or float32 family.  This is raised
            before any record is moved.
    """
    if not self.validate_match(subgraph):
        raise DontReplaceError()
    for step, node in enumerate(subgraph.nodes()):
        node.step_idx = step
        if isinstance(node, Conv2DParameters):
            conv_name = node.name + "_fusion"
            break
    LOG.debug("fused nodes %s", ",".join(node.name for node in subgraph.nodes()))
    # simple node order is necessary because nodes() will not necessarily
    # be in order
    pnode = ConvFusionParameters(conv_name, fusion_type=self.fusion_type, subgraph=subgraph)
    if G.quantization:
        qrecs = G.quantization.get_all(pnode.contained_nodes())
        if qrecs:
            first_qrec = qrecs[0]
            if isinstance(first_qrec, (SymmetricQuantizationRecord,
                                       SymmetricScalableFilterQuantizationRecord)):
                prec_class = SymmetricQuantizationRecord
            elif isinstance(first_qrec, (MultQuantizationRecord,
                                         MultScalableFilterQuantizationRecord)):
                prec_class = MultQuantizationRecord
            elif isinstance(first_qrec, (Float32QuantizationRecord,
                                         Float32ScalableFilterQuantizationRecord)):
                prec_class = Float32QuantizationRecord
            else:
                # Fail before moving records instead of hitting an unbound
                # ``prec`` afterwards.
                raise NotImplementedError(
                    "unsupported quantization record type %s for fusion"
                    % type(first_qrec).__name__)
            prec = prec_class(in_qs=first_qrec.in_qs, out_qs=qrecs[-1].out_qs)
            for node in pnode.contained_nodes():
                G.quantization.move_to_fusion(node, pnode)
            G.quantization[NodeId(pnode)] = prec
    return pnode, None, None
def replace_function(self, G: NNGraph, subgraph: GraphView):
    """Replace a matched FC + activation subgraph with a ConvFusionParameters node.

    Nodes yielded by ``subgraph.nodes()`` are numbered through ``step_idx``
    up to and including the first FcParameters node.  That node's name,
    suffixed with ``_fusion``, names the fusion (type ``linear_active``).
    When ``G`` is quantized, the contained nodes' records are moved into the
    fusion and the fusion receives a record combining the first record's
    ``in_qs`` with the last record's ``out_qs``.

    Returns:
        ``(pnode, None, None)`` where ``pnode`` is the new fusion node.

    Raises:
        NotImplementedError: if the first quantization record is not of the
            symmetric, multiplicative or float32 family.  This is raised
            before any record is moved.
    """
    for step, node in enumerate(subgraph.nodes()):
        node.step_idx = step
        if isinstance(node, FcParameters):
            linear_name = node.name + "_fusion"
            break
    LOG.info("fusing nodes %s", ",".join(node.name for node in subgraph.nodes()))
    # simple node order is necessary because nodes() will not necessarily
    # be in order
    pnode = ConvFusionParameters(linear_name, fusion_type="linear_active", subgraph=subgraph)
    if G.quantization:
        qrecs = G.quantization.get_all(pnode.contained_nodes())
        if qrecs:
            first_qrec = qrecs[0]
            if isinstance(first_qrec, (SymmetricQuantizationRecord,
                                       SymmetricScalableFilterQuantizationRecord)):
                prec_class = SymmetricQuantizationRecord
            elif isinstance(first_qrec, (MultQuantizationRecord,
                                         MultScalableFilterQuantizationRecord)):
                prec_class = MultQuantizationRecord
            elif isinstance(first_qrec, (Float32QuantizationRecord,
                                         Float32ScalableFilterQuantizationRecord)):
                prec_class = Float32QuantizationRecord
            else:
                # Fail before moving records instead of hitting an unbound
                # ``prec`` afterwards.
                raise NotImplementedError(
                    "unsupported quantization record type %s for fusion"
                    % type(first_qrec).__name__)
            prec = prec_class(in_qs=first_qrec.in_qs, out_qs=qrecs[-1].out_qs)
            for node in pnode.contained_nodes():
                G.quantization.move_to_fusion(node, pnode)
            G.quantization[NodeId(pnode)] = prec
    return pnode, None, None
def replace_function(self, G: GraphView, subgraph: GraphView):
    """Wrap a matched convolution subgraph in a single FusionParameters node.

    The fusion is named after the first Conv2DParameters node found in
    ``subgraph.nodes()`` (with ``_fusion`` appended).  It uses
    ``self.fusion_type`` and receives the subgraph's nodes in ``dfs()`` order.
    """
    for member in subgraph.nodes():
        if isinstance(member, Conv2DParameters):
            conv_name = member.name + "_fusion"
            break
    member_names = ",".join(m.name for m in subgraph.nodes())
    LOG.debug("fused nodes %s", member_names)
    # nodes() is not necessarily ordered, so the fusion's node list is taken
    # from dfs() instead
    ordered_members = list(subgraph.dfs())
    return FusionParameters(conv_name, self.fusion_type, ordered_members)
def replace_function(self, G: GraphView, subgraph: GraphView):
    """Wrap a matched FC + activation subgraph in a ``linear_active`` fusion.

    Nodes from ``subgraph.nodes()`` receive consecutive ``step_idx`` values
    up to and including the first FcParameters node.  That node's name, with
    ``_fusion`` appended, names the returned FusionParameters.
    """
    for position, member in enumerate(subgraph.nodes()):
        member.step_idx = position
        if isinstance(member, FcParameters):
            linear_name = member.name + "_fusion"
            break
    member_names = ",".join(m.name for m in subgraph.nodes())
    LOG.debug("fused nodes %s", member_names)
    return FusionParameters(linear_name, "linear_active", subgraph)
def replace_function(self, G: GraphView, subgraph: GraphView):
    """Wrap a validated convolution subgraph in a FusionParameters node.

    Raises DontReplaceError when ``self.validate_match`` rejects the subgraph.
    Otherwise nodes from ``subgraph.nodes()`` get consecutive ``step_idx``
    values up to and including the first Conv2DParameters node, whose name
    (plus ``_fusion``) names the fusion of type ``self.fusion_type``.
    """
    if not self.validate_match(subgraph):
        raise DontReplaceError()
    for position, member in enumerate(subgraph.nodes()):
        member.step_idx = position
        if isinstance(member, Conv2DParameters):
            conv_name = member.name + "_fusion"
            break
    member_names = ",".join(m.name for m in subgraph.nodes())
    LOG.debug("fused nodes %s", member_names)
    return FusionParameters(conv_name, self.fusion_type, subgraph)
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse chains that start at a node whose class is a key of VALID_FUSIONS.

    For every candidate head, ``self.get_node_list`` proposes a chain.  Chains
    with fewer than two nodes are ignored.  Each accepted chain is handled in
    four steps:

    - It is wrapped in a node of ``fusions_class`` whose subgraph links the
      chain members in order.
    - Its quantization records, if any, are moved into the fusion.
    - Its members are removed from ``G``.
    - The external edges are reattached to the fusion.

    Returns:
        True if at least one chain was fused.
    """
    fused_any = False
    head_classes = tuple(VALID_FUSIONS.keys())
    for head in G.nodes(node_classes=head_classes):
        found = self.get_node_list(G, head, FusionMatch(self._default_ktype))
        if found is None or len(found.order) < 2:
            continue
        chain = list(found.order)
        LOG.info("fusing nodes %s", ",".join(member.name for member in chain))
        fused_any = True

        # internal subgraph: a straight line through the chain members
        subgraph = GraphView()
        for upstream, downstream in zip(chain, chain[1:]):
            subgraph.add_edge(NNEdge(from_node=upstream, to_node=downstream))
        tail = chain[-1]

        # assumption here is that the first node could have multiple inputs
        # but definitely has only one output
        head_node = found.node
        input_mapping = [[(head_node, in_idx)]
                         for in_idx in range(G.num_in_edges(head_node.name))]
        output_mapping = [(tail, 0)]
        pnode = found.fusions_class(head_node.name + '_fusion',
                                    fusion_type=found.fusion_type,
                                    subgraph=subgraph,
                                    input_mapping=input_mapping,
                                    output_mapping=output_mapping)

        if G.quantization:
            # TODO - stats
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                fused_qrec = QRec.copy_ktype(qrecs[0],
                                             in_qs=qrecs[0].in_qs,
                                             out_qs=qrecs[-1].out_qs)
                for inner in pnode.contained_nodes():
                    G.quantization.move_to_fusion(inner, pnode)
                G.quantization[NodeId(pnode)] = fused_qrec

        incoming = G.in_edges(head_node.name)
        outgoing = G.out_edges(tail.name)
        for member in chain:
            G.remove(member)
        for edge in incoming:
            G.add_edge(NNEdge(edge.from_node, pnode,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in outgoing:
            G.add_edge(NNEdge(pnode, edge.to_node,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))

    if set_identity:
        self.set_identity(G)
    return fused_any
def match(self, G: GraphView, set_identity: bool = True):
    """Apply ``self.do_change`` to nodes whose output change all consumers accept.

    Only runs on quantized graphs.  A node qualifies when
    ``self.can_change_output`` accepts it and ``self.can_change_input``
    returns a non-None match list for every successor.  The collected mapping
    is filtered by ``self.validate_multi_input`` before changes are applied.
    Returns None.
    """
    if not G.quantization:
        return
    changeable = {}
    for candidate in G.nodes():
        if not self.can_change_output(candidate):
            continue
        # G.successors yields groups of successors; flatten them
        consumers = [succ for succs in G.successors(candidate.name) for succ in succs]
        collected = []
        for consumer in consumers:
            consumer_matches = self.can_change_input(G, consumer)
            if consumer_matches is None:
                break
            collected += consumer_matches
        else:
            changeable[candidate] = collected
    changeable = self.validate_multi_input(G, changeable)
    for candidate in changeable:
        # all nodes that can currently change output have one output
        self.do_change(G, candidate)
    if set_identity:
        self.set_identity(G)
def match(self, G: GraphView, set_identity: bool = True):
    """Insert CopyParameters between split outputs and the concats they feed.

    For each SplitParameters node, ``self.find_split_concat`` returns one
    bundle of edge sets per split output (or None).  Processing a non-empty
    bundle involves three steps:

    - A copy node ``<split>_copy_<idx>`` is created.
    - The first edge of each edge set is re-sourced from that copy node.
    - The copy node is fed from split output ``idx``.

    On quantized graphs the copy node gets a record whose input and output
    are both a copy of the split's output qtype ``idx``.

    Returns:
        True if any copy node was inserted.
    """
    split_nodes = [node for node in G.nodes() if isinstance(node, SplitParameters)]
    has_modified_graph = False
    for node in split_nodes:
        # traverse reshapes or transposes that do nothing - check gen
        # find edges connected to concats
        res = self.find_split_concat(G, node)
        if res is None:
            continue
        # TODO(martin) - group edges that have adjacent inputs and outputs
        if G.quantization:
            qrec = G.quantization[NodeId(node)]
        for idx, bundle in enumerate(res):
            if not bundle:
                continue
            has_modified_graph = True
            copy_node = CopyParameters("%s_copy_%s" % (node.name, idx))
            for edge_set in bundle:
                first_edge = edge_set[0]
                G.remove_edge(first_edge)
                G.add_edge(NNEdge(copy_node, first_edge.to_node, to_idx=first_edge.to_idx))
            G.add_edge(NNEdge(node, copy_node, from_idx=idx))
            if G.quantization:
                # in_qs / out_qs hold one QType per input / output, so wrap
                # the single copied QType in a list
                G.quantization[NodeId(copy_node)] = qrec.__class__(
                    in_qs=[deepcopy(qrec.out_qs[idx])],
                    out_qs=[deepcopy(qrec.out_qs[idx])])
    # honour the flag like the other match() implementations do
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fold a TransposeParameters feeding input 1 of a MatMul into the MatMul.

    A MatMulOpParameters whose second input comes from a transpose is
    replaced by a same-named MatMulTransposedParameters.  A
    MatMulTransposedParameters is replaced by a plain MatMulOpParameters.
    The transpose is then removed and its input is wired to input 1 of the
    new node.  A transpose that also feeds other nodes is left in place.

    Returns:
        True if any MatMul was rewritten.
    """
    has_modified_graph = False
    # snapshot: nodes are replaced and removed inside the loop
    for node in list(G.nodes(node_classes=MatMulOpParameters)):
        in_edges = list(G.indexed_in_edges(node.name))
        if len(in_edges) < 2:
            continue
        trans_node = in_edges[1].from_node
        if not isinstance(trans_node, TransposeParameters):
            continue
        # The transpose is deleted below, so it may only be folded when this
        # MatMul is its sole consumer; otherwise other consumers would lose
        # their input.
        if len(G.out_edges(trans_node.name)) != 1:
            continue
        # NOTE(review): the transpose's permutation is not inspected here;
        # presumably only last-two-axes swaps reach this point - confirm.
        if isinstance(node, MatMulTransposedParameters):
            new_node = MatMulOpParameters(node.name)
        else:
            new_node = MatMulTransposedParameters(node.name)
        in_trans_edge = G.indexed_in_edges(trans_node.name)[0]
        trans_nid = NodeId(trans_node)
        G.replace_node(node.name, new_node)
        G.remove(trans_node)
        if G.quantization and trans_nid in G.quantization:
            del G.quantization[trans_nid]
        G.add_edge(NNEdge(in_trans_edge.from_node, new_node,
                          from_idx=in_trans_edge.from_idx, to_idx=1))
        has_modified_graph = True

    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def match(self, G: GraphView, set_identity: bool = True):
    """Widen the input quantization range of (hard) sigmoid / hard swish nodes.

    Only runs on quantized graphs.  For sigmoid and hard sigmoid the input
    range becomes [0, 6].  For hard swish it becomes [0, 6 * current max].
    A new QType is built with ``QType.from_min_max_sq`` using the original
    dtype and propagated upward along the node's input-0 edge.

    Returns:
        None if the graph is not quantized, otherwise False.
    """
    if not G.quantization:
        return
    act_classes = (SigmoidActivationParameters, HSigmoidActivationParameters,
                   HSwishActivationParameters)
    act_nodes = [candidate for candidate in G.nodes() if isinstance(candidate, act_classes)]
    act_qrecs = [G.quantization[NodeId(candidate)] for candidate in act_nodes]
    for act_node, act_qrec in zip(act_nodes, act_qrecs):
        first_input = [edge for edge in G.in_edges(act_node.name) if edge.to_idx == 0][0]
        in_q = act_qrec.in_qs[0]
        lower, upper = in_q.min_val, in_q.max_val
        if isinstance(act_node, (HSigmoidActivationParameters, SigmoidActivationParameters)):
            # Hard sigmoid implements a RELU, be sure 6 can be representable
            lower, upper = 0, 6
        elif isinstance(act_node, HSwishActivationParameters):
            lower, upper = 0, in_q.max_val * 6
        widened_q = QType.from_min_max_sq(min_val=lower, max_val=upper, dtype=in_q.dtype)
        propagate_qtype_up(G, widened_q, first_input)
    if set_identity:
        self.set_identity(G)
    return False
def match(self, G: GraphView, set_identity: bool = True):
    """Materialise pending transposes of Transposable nodes as TransposeParameters.

    For each Transposable node other than TransposeParameters:

    - If ``transpose_in`` is set, a TransposeParameters using it is inserted
      on every input edge.
    - If ``transpose_out`` is set, one is inserted on every output edge.

    Dims hints are moved onto the inserted node and the node's pending
    transpose is cleared.  Returns None.
    """
    # Collect first so the transposes inserted below are not revisited.
    tnodes = [n for n in G.nodes()
              if isinstance(n, Transposable) and not isinstance(n, TransposeParameters)]
    for node in tnodes:
        if node.transpose_in:
            for idx, edge in enumerate(G.in_edges(node.name)):
                in_params = TransposeParameters("%s_TIN_%s" % (node.name, idx),
                                                transpose=node.transpose_in)
                if node.in_dims_hint:
                    in_hint = node.in_dims_hint[edge.to_idx]
                    out_hint = apply_reverse_transpose_to_hint(in_hint, node.transpose_in)
                    in_params.in_dims_hint = [in_hint.copy()]
                    in_params.out_dims_hint = [out_hint.copy()]
                    node.in_dims_hint[edge.to_idx] = out_hint
                G.insert_node(in_params, edge.from_node.name, edge.to_node.name,
                              from_idx=edge.from_idx, to_idx=edge.to_idx)
            node.transpose_in = None
        if node.transpose_out:
            # Several out edges can share one from_idx.  Read hints from a
            # snapshot so the reverse transpose is not applied to a hint that
            # an earlier edge already rewrote.
            orig_out_hints = list(node.out_dims_hint) if node.out_dims_hint else None
            for idx, edge in enumerate(G.out_edges(node.name)):
                out_params = TransposeParameters("%s_TOUT_%s" % (node.name, idx),
                                                 transpose=node.transpose_out)
                if orig_out_hints:
                    out_hint = orig_out_hints[edge.from_idx]
                    in_hint = apply_reverse_transpose_to_hint(out_hint, node.transpose_out)
                    out_params.in_dims_hint = [in_hint.copy()]
                    out_params.out_dims_hint = [out_hint.copy()]
                    node.out_dims_hint[edge.from_idx] = in_hint
                G.insert_node(out_params, edge.from_node.name, edge.to_node.name,
                              from_idx=edge.from_idx, to_idx=edge.to_idx)
            node.transpose_out = None
    if set_identity:
        self.set_identity(G)
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Bypass and delete every node that ``self.node_does_nothing`` flags.

    The source of the node's first in-edge is connected directly to each of
    the node's consumers, keeping the consumer-side input index.

    Returns:
        True if any node was removed.
    """
    idle_nodes = [candidate for candidate in G.nodes()
                  if self.node_does_nothing(G, candidate)]
    for node in idle_nodes:
        feed = G.in_edges(node.name)[0]
        G.remove_edge(feed)
        for consumer_edge in G.out_edges(node.name):
            G.remove_edge(consumer_edge)
            G.add_edge(NNEdge(feed.from_node, consumer_edge.to_node,
                              from_idx=feed.from_idx, to_idx=consumer_edge.to_idx))
        LOG.info(f'removing {node.name} that does nothing')
        G.remove(node)
    if set_identity:
        self.set_identity(G)
    return bool(idle_nodes)
def replace_function(self, G: GraphView, subgraph: GraphView):
    """Replace a ReLU / constant / multiply subgraph with an HSigmoid activation.

    The activation is named after the multiply node and built with
    ``offset=0``.  On quantized graphs the constant's record is deleted and
    the activation gets a copy (``QRec.copy_ktype``) of the ReLU record, with
    the ReLU's ``in_qs`` and the multiply's ``out_qs``.

    Returns:
        ``(activation, None, None)``.
    """
    relu_node = constant_node = mul_node = None
    for member in subgraph.nodes():
        if isinstance(member, ReluActivationParameters):
            relu_node = member
        elif isinstance(member, ConstantInputParameters):
            constant_node = member
        elif isinstance(member, MatrixMulParameters):
            mul_node = member
    activation = HSigmoidActivationParameters(
        mul_node.name + "_fused_close_hsigmoid", offset=0)
    if G.quantization:
        relu_qrec = G.quantization[NodeId(relu_node)]
        mul_qrec = G.quantization[NodeId(mul_node)]
        del G.quantization[NodeId(constant_node)]
        G.quantization[NodeId(activation)] = QRec.copy_ktype(
            relu_qrec, in_qs=relu_qrec.in_qs, out_qs=mul_qrec.out_qs)
    return activation, None, None
def replace_function(self, G: GraphView, subgraph: GraphView):
    """Replace a ReLU / constant / multiply subgraph with an HSigmoid activation.

    On quantized graphs the constant's record is deleted.  The activation
    gets a new record of the same family as the ReLU record (symmetric,
    multiplicative or float32) with the ReLU's ``in_qs`` and the multiply's
    ``out_qs``.  Any other record family raises NotImplementedError.

    Returns:
        ``(activation, None, None)``.
    """
    relu_node = constant_node = mul_node = None
    for member in subgraph.nodes():
        if isinstance(member, ReluActivationParameters):
            relu_node = member
        elif isinstance(member, ConstantInputParameters):
            constant_node = member
        elif isinstance(member, MatrixMulParameters):
            mul_node = member
    activation = HSigmoidActivationParameters(
        mul_node.name + "_fused_close_hsigmoid", offset=0)
    if G.quantization:
        relu_qrec = G.quantization[NodeId(relu_node)]
        mul_qrec = G.quantization[NodeId(mul_node)]
        del G.quantization[NodeId(constant_node)]
        for record_class in (SymmetricQuantizationRecord, MultQuantizationRecord,
                             Float32QuantizationRecord):
            if isinstance(relu_qrec, record_class):
                break
        else:
            raise NotImplementedError()
        G.quantization[NodeId(activation)] = record_class(
            in_qs=relu_qrec.in_qs, out_qs=mul_qrec.out_qs)
    return activation, None, None
def match(self, G: GraphView, set_identity: bool = True):
    """Force power-of-two input scaling on softmax nodes.

    Only runs when the graph is quantized and every softmax record is a
    MultQuantizationRecord.  Each softmax's input qtype is scaled to a power
    of two in place and propagated up its input edges.  Its consumers must be
    Output or Quantize nodes (asserted), and their ``in_qs[0]`` and
    ``out_qs[0]`` are set to the softmax's output qtype.

    Returns:
        None if the pass does not apply, otherwise False.
    """
    if not G.quantization:
        return
    softmax_nodes = [n for n in G.nodes() if isinstance(n, SoftMaxParameters)]
    softmax_qrecs = [G.quantization[NodeId(n)] for n in softmax_nodes]
    if any(not isinstance(rec, MultQuantizationRecord) for rec in softmax_qrecs):
        return
    for softmax, rec in zip(softmax_nodes, softmax_qrecs):
        input_q = rec.in_qs[0]
        input_q.scale_to_pow2()
        for feed in G.in_edges(softmax.name):
            propagate_qtype_up(G, input_q, feed)
        for consumer_edge in G.out_edges(softmax.name):
            assert isinstance(
                consumer_edge.to_node, (OutputParameters, QuantizeParameters)
            ), "Softmax is supported only at the end of the graph"
            consumer_rec = G.quantization[NodeId(consumer_edge.to_node)]
            consumer_rec.in_qs[0] = rec.out_qs[0]
            consumer_rec.out_qs[0] = rec.out_qs[0]
    if set_identity:
        self.set_identity(G)
    return False
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Run ``self.find_direct_connects`` on all splits, then on all concats.

    The modified flag is threaded through every call.  Concats are visited
    with ``find_output=False``.  ``set_identity`` is not used by this pass.

    Returns:
        The flag returned by the last ``find_direct_connects`` call (False if
        there are no splits or concats).
    """
    modified = False
    for split_node in list(G.nodes(node_classes=SplitParameters)):
        modified = self.find_direct_connects(G, split_node, modified)
    concat_nodes = list(G.nodes(node_classes=ConcatParameters))
    for concat_node in concat_nodes:
        modified = self.find_direct_connects(G, concat_node, modified, find_output=False)
    return modified
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Remove or simplify strided slices whose slice covers the whole input.

    A StridedSliceParameters whose ``slice_shape`` equals its input shape is
    handled in one of two ways:

    - If its ``out_shape`` also matches, it is removed and the graph is
      reconnected.
    - Otherwise it is replaced by a ReshapeParameters from ``slice_shape`` to
      ``out_shape``.

    The quantization record is deleted or moved to the reshape accordingly.

    Returns:
        True if any strided slice was removed or replaced.
    """
    has_modified_graph = False
    slice_nodes = list(G.nodes(node_classes=StridedSliceParameters))
    for node in slice_nodes:
        if node.slice_shape != tuple(node.in_dims[0].shape):
            continue
        has_modified_graph = True
        nid = NodeId(node)
        if node.slice_shape != node.out_shape:
            reshape = ReshapeParameters(G.unique_name(f'{node.name}_reshape'),
                                        old_shape=node.slice_shape,
                                        shape=node.out_shape)
            LOG.info(f'replacing strided slice {node.name} with reshape {reshape.name}')
            G.replace_node(node, reshape)
            if G.quantization and nid in G.quantization:
                G.quantization[NodeId(reshape)] = G.quantization[nid]
                del G.quantization[nid]
            continue
        LOG.info(f'removing strided slice {node.name} that does nothing')
        G.remove_and_reconnect(node, edge_class=NNEdge)
        if G.quantization and nid in G.quantization:
            del G.quantization[nid]
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def replace_function(self, G: GraphView, subgraph: GraphView):
    """Absorb a constant bias into the biases of the filter node it follows.

    The constant's values are flattened.  If their length differs from the
    filter's ``out_c``, DontReplaceError is raised.  Otherwise they are added
    to the existing biases, or become the biases if the filter has none.  On
    quantized graphs the filter record's ``biases_q`` is set to the constant's
    output qtype when both records exist.

    Returns:
        ``(filter_node, None, None)``.
    """
    filter_node = constant_node = None
    for member in subgraph.nodes():
        if isinstance(member, FilterParameters):
            filter_node = member
        elif isinstance(member, ConstantInputParameters):
            constant_node = member
    LOG.info("fusing bias in %s into %s", constant_node.name, filter_node.name)
    bias_values = constant_node.value.flatten()
    # shape needs to match
    if bias_values.shape[0] != filter_node.filter.out_c:
        raise DontReplaceError()
    if filter_node.has_bias:
        assert filter_node.biases is not None, "can't absorb bias into filter. maybe weights are not loaded"
        filter_node.biases += bias_values
    else:
        filter_node.biases = bias_values
    if G.quantization:
        filter_nid = NodeId(filter_node)
        constant_nid = NodeId(constant_node)
        if filter_nid in G.quantization and constant_nid in G.quantization:
            G.quantization[filter_nid].biases_q = G.quantization[constant_nid].out_qs[0]
    return filter_node, None, None
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Remove QuantizeParameters nodes by pushing their qtype into neighbours.

    The direction depends on the source type:

    - Float-to-float quantizes are skipped with a warning.
    - Float-to-quantized nodes try ``self.propagate_up`` with the target qtype.
    - Other nodes try ``self.propagate_down`` with the source qtype.

    On success the node and its quantization record are removed; otherwise
    a warning is logged.

    Returns:
        True if any quantize node was removed.
    """
    float_kinds = (np.floating, bfloat16)
    modified_graph = False
    for node in G.nodes(node_classes=QuantizeParameters):
        if issubclass(node.from_qtype.dtype, float_kinds):
            if issubclass(node.to_qtype.dtype, float_kinds):
                LOG.warning(
                    'node %s quantizes from floating type to floating type and cannot directly be removed',
                    node.name)
                continue
            removable = self.propagate_up(G, node, node.to_qtype)
        else:
            removable = self.propagate_down(G, node, node.from_qtype)
        if not removable:
            LOG.warning('unable to remove quantize node %s', node.name)
            continue
        modified_graph = True
        G.remove_and_reconnect(node, edge_class=NNEdge)
        if G.quantization:
            del G.quantization[NodeId(node)]
    if set_identity:
        self.set_identity(G)
    return modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Remove duplicated node strings hanging off a single fanned-out output.

    Candidates have one indexed output that is consumed by more than one
    edge.  For each candidate, ``self.explore`` returns strings (sequences of
    nodes).  The first string is kept.  Every other string is deleted, along
    with its quantization records, and the consumers of its last node are
    reattached to the last node of the kept string.

    Returns:
        True if any duplicate string was removed.
    """
    modified_graph = False
    pending = [candidate for candidate in G.nodes()
               if len(G.indexed_out_edges(candidate.name)) == 1
               and len(G.out_edges(candidate.name)) > 1]
    while pending:
        start = pending.pop(0)
        strings = self.explore(G, [start])
        if not strings:
            continue
        modified_graph = True
        primary = strings.pop(0)
        for kept in primary:
            if kept in pending:
                pending.remove(kept)
        orphaned_edges = []
        for other in strings:
            orphaned_edges.extend(G.out_edges(other[-1].name))
            for dup in other:
                if dup in pending:
                    pending.remove(dup)
                G.remove(dup)
                dup_nid = NodeId(dup)
                if G.quantization and dup_nid in G.quantization:
                    del G.quantization[dup_nid]
            LOG.info(
                f'removed duplicates from {primary[0].name} {",".join(node.name for node in other)}')
        primary_end = primary[-1]
        for edge in orphaned_edges:
            G.add_edge(NNEdge(from_node=primary_end, to_node=edge.to_node, to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return modified_graph
def match(self, G: GraphView, set_identity: bool = True):
    """Give every consumer of a shared constant its own ConstantInputParameters.

    For a constant with more than one out edge, the first edge keeps the
    original node.  Each further edge is moved onto a new constant named
    ``<name>_<idx>`` that holds a copy of the value.  On quantized graphs the
    original constant's record is copied to each duplicate.

    Returns:
        True if any constant was duplicated.
    """
    has_modified = False
    # snapshot: new constant nodes are added while iterating
    for node in list(G.nodes(node_classes=ConstantInputParameters)):
        out_edges = G.out_edges(node.name)
        if len(out_edges) <= 1:
            continue
        has_modified = True
        LOG.info('node %s has more than one out edge and will be duplicated', node.name)
        for idx, out_edge in enumerate(out_edges[1:], start=1):
            new_constant = ConstantInputParameters(f'{node.name}_{idx}',
                                                   dims=Dim.unnamed(node.dims.shape),
                                                   value=node.value.copy())
            G.remove_edge(out_edge)
            G.add_edge(NNEdge(from_node=new_constant, to_node=out_edge.to_node,
                              to_idx=out_edge.to_idx))
            if G.quantization:
                # without this the duplicate would have no quantization record
                G.quantization.copy_to_node(node, new_constant)
    if set_identity:
        self.set_identity(G)
    return has_modified
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Find multiply-by-1/6 patterns and hand them to ``process_rec``.

    Candidates are MatrixMulParameters with an input fed by a
    ConstantInputParameters equal to 1/6 (per ``check_equals``).  ``look_back``
    turns each candidate into an operation record or None.  If the record's
    ``'mul'`` node has a single consumer that is a ReLU, that ReLU and its
    quantization record are stored under ``'relu3'``.  Every record is then
    passed to ``process_rec``.

    Returns:
        True if any record was processed.
    """
    def _feeds_one_sixth(edge):
        return (isinstance(edge.from_node, ConstantInputParameters)
                and check_equals(G, edge.from_node, 1.0 / 6.0))

    const_ops = [node for node in G.nodes()
                 if isinstance(node, MatrixMulParameters)
                 and any([_feeds_one_sixth(edge) for edge in G.in_edges(node.name)])]
    oprecs = []
    for op in const_ops:
        rec = look_back(G, op)
        if rec is not None:
            oprecs.append(rec)
    has_modified_graph = False
    for oprec in oprecs:
        mul_out = G.out_edges(oprec['mul'][0].name)
        if len(mul_out) == 1 and isinstance(mul_out[0].to_node, ReluActivationParameters):
            relu = mul_out[0].to_node
            oprec['relu3'] = (relu, G.quantization[NodeId(relu)])
        has_modified_graph = True
        process_rec(G, oprec)
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Eliminate unpack / slice nodes that select RNN output cells.

    RNN nodes with more than one output cell are mapped to their unpack node
    (``self.find_unpack``).  The result is filtered by ``validate_slices`` and
    ``validate_multi_branch``, and each path is processed by
    ``self.process_path``.  Each unpack node is then removed:

    - If it is a StridedSliceParameters that changes shape, a
      ReshapeParameters is inserted to perform the shape change.
    - Otherwise its input is wired directly to its consumers.

    Its quantization record moves to the reshape (if any) and is then deleted.

    Returns:
        False if nothing matched, otherwise True.
    """
    rnn_nodes = [self.find_unpack(G, node) for node in G.nodes()
                 if isinstance(node, RNNBaseParameters) and node.n_output_cells > 1]
    rnn_nodes_by_slice = self.validate_slices(G, rnn_nodes)
    rnn_nodes_by_slice = self.validate_multi_branch(G, rnn_nodes_by_slice)
    if not rnn_nodes_by_slice:
        return False
    for unpack_node, rnn_unpacks in rnn_nodes_by_slice.items():
        modified_nodes = set()
        for rnn_unpack in rnn_unpacks:
            self.process_path(G, rnn_unpack, modified_nodes)
        # since process path will have removed all unnecessary nodes the edges
        # will be correct here
        out_edges = G.out_edges(unpack_node.name)
        in_edges = G.in_edges(unpack_node.name)
        assert len(in_edges) == 1, "expecting unpack node to have only one in edge"
        in_edge = in_edges[0]
        changes_shape = (unpack_node.changes_shape
                         if isinstance(unpack_node, StridedSliceParameters) else False)
        LOG.info("Eliminating last cell unpack: %s", unpack_node.name)
        unpack_nid = NodeId(unpack_node)
        G.remove(unpack_node)
        if changes_shape:
            # Here the strided slice can change the output shape of the RNN
            # so insert a reshape to do the shape change
            reshape = ReshapeParameters(
                unpack_node.name + '_reshape',
                old_shape=Dim.unnamed(unpack_node.post_slice_shape),
                shape=Dim.unnamed(unpack_node.out_shape))
            G.add_edge(NNEdge(from_node=in_edge.from_node, to_node=reshape,
                              from_idx=in_edge.from_idx))
            for out_edge in out_edges:
                G.add_edge(NNEdge(from_node=reshape, to_node=out_edge.to_node,
                                  to_idx=out_edge.to_idx))
            if G.quantization:
                # previously NodeId(unpack): an undefined name (NameError)
                G.quantization[NodeId(reshape)] = G.quantization[unpack_nid]
        else:
            for out_edge in out_edges:
                G.add_edge(NNEdge(from_node=in_edge.from_node, to_node=out_edge.to_node,
                                  from_idx=in_edge.from_idx, to_idx=out_edge.to_idx))
        if G.quantization:
            # the unpack node is gone on both paths; drop its stale record
            del G.quantization[unpack_nid]
    if set_identity:
        self.set_identity(G)
    return True
def match(self, G: GraphView, set_identity: bool = True):
    """Fuse each Conv2DParameters with the activation / pooling that follows it.

    ``self.get_node_list`` proposes a chain per convolution, and chains with
    fewer than two nodes are skipped.  Two fusion types get special handling:

    - ``conv_active_pool`` with average pooling: the pool is dropped from the
      chain.
    - ``conv_pool_active`` with average pooling and a non-relu activation:
      the chain is skipped.

    Each remaining chain becomes a ConvFusionParameters.  Its quantization
    records are moved into the fusion, and the fusion receives a combined
    record of the same family as the first record, or None for an unknown
    family.  The chain is then replaced in ``G``.

    Returns:
        True if any fusion was created.
    """
    record_families = (
        ((SymmetricQuantizationRecord, SymmetricScalableFilterQuantizationRecord),
         SymmetricQuantizationRecord),
        ((MultQuantizationRecord, MultScalableFilterQuantizationRecord),
         MultQuantizationRecord),
        ((Float32QuantizationRecord, Float32ScalableFilterQuantizationRecord),
         Float32QuantizationRecord),
    )
    fused_any = False
    conv_nodes = [params for params in G.nodes() if isinstance(params, Conv2DParameters)]
    for conv_node in conv_nodes:
        found = self.get_node_list(G, conv_node)
        if found is None or len(found.order) < 2:
            continue
        kind = found.fusion_type
        if kind == 'conv_active_pool':
            if found.pool.pool_type == "average":
                found.order = found.order[:2]
                found.pool = None
        elif kind == 'conv_pool_active':
            if found.pool.pool_type == "average" and found.active.activation != "relu":
                continue
        chain = found.order
        LOG.info("fusing nodes %s", ",".join(member.name for member in chain))
        fused_any = True

        subgraph = GraphView()
        tail = None
        for member in chain:
            if tail is not None:
                subgraph.add_edge(NNEdge(from_node=tail, to_node=member))
            tail = member
        input_mapping = [[(found.conv, in_idx)] for in_idx in range(3)]
        output_mapping = [(tail, 0)]
        pnode = ConvFusionParameters(
            found.conv.name + '_fusion',
            fusion_type=found.fusion_type,
            subgraph=subgraph,
            in_dims_hint=found.conv.in_dims_hint,
            out_dims_hint=found.conv.out_dims_hint,
            input_mapping=input_mapping,
            output_mapping=output_mapping)

        if G.quantization:
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                fused_qrec = None
                for source_classes, record_class in record_families:
                    if isinstance(qrecs[0], source_classes):
                        fused_qrec = record_class(in_qs=qrecs[0].in_qs,
                                                  out_qs=qrecs[-1].out_qs)
                        break
                for inner in pnode.contained_nodes():
                    G.quantization.move_to_fusion(inner, pnode)
                G.quantization[NodeId(pnode)] = fused_qrec

        incoming = G.in_edges(found.conv.name)
        outgoing = G.out_edges(tail.name)
        for member in found.order:
            G.remove(member)
        for edge in incoming:
            G.add_edge(NNEdge(edge.from_node, pnode,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in outgoing:
            G.add_edge(NNEdge(pnode, edge.to_node,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return fused_any
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Shrink convolution kernels that are larger than their input.

    This applies to Conv2DParameters and to ConvFusionParameters whose first
    contained node is a Conv2DParameters.  Kernels that fit the input are left
    alone.  When both the clipped height and width exceed 1, only a warning is
    logged.  Otherwise each dimension whose clipped size is 1 is reduced to 1.
    The filter dim is updated and the weights constant (input 1 of the node
    in ``G``) is sliced to match, using the top/left padding to choose the
    kept row/column.

    Returns:
        True if any kernel was shrunk.
    """
    something_changed = False
    candidates = [node for node in G.nodes()
                  if isinstance(node, (Conv2DParameters, ConvFusionParameters))]
    for graph_node in candidates:
        # graph_node is what lives in G; conv is the convolution itself
        conv = graph_node
        if isinstance(graph_node, ConvFusionParameters):
            conv = graph_node.contained_nodes()[0]
        if not isinstance(conv, Conv2DParameters):
            continue
        in_shape = conv.in_dims[0]
        kernel = conv.filter
        if kernel.h <= in_shape.h and kernel.w <= in_shape.w:
            continue
        min_h = min(kernel.h, in_shape.h)
        min_w = min(kernel.w, in_shape.w)
        if min_h > 1 and min_w > 1:
            LOG.warning("Filter of %s [%dx%d] bigger than input [%dx%d] not optimal but will work on AT",
                        conv.name, kernel.h, kernel.w, in_shape.h, in_shape.w)
            continue
        ker_h = 1 if min_h == 1 else kernel.h
        ker_w = 1 if min_w == 1 else kernel.w
        if (ker_h, ker_w) == (kernel.h, kernel.w):
            continue
        new_kernel = Conv2DFilterDim(ker_h, ker_w, kernel.out_c, in_c=kernel.in_c)
        LOG.warning("Converting filter of %s from [%dx%d] -> [%dx%d]",
                    conv.name, kernel.h, kernel.w, new_kernel.h, new_kernel.w)
        conv.filter = new_kernel
        weight_slices = []
        for dim in kernel.order:
            if dim in ('out_c', 'in_c'):
                weight_slices.append(slice(None))
            elif dim == 'h':
                weight_slices.append(
                    slice(conv.padding.t, conv.padding.t + 1) if new_kernel.h == 1
                    else slice(0, new_kernel.h))
            elif dim == 'w':
                weight_slices.append(
                    slice(conv.padding.l, conv.padding.l + 1) if new_kernel.w == 1
                    else slice(0, new_kernel.w))
        weights_node = G.indexed_in_edges(graph_node.name)[1].from_node
        weights_node.value = weights_node.value[tuple(weight_slices)]
        weights_node.dims = Dim.unnamed(weights_node.value.shape)
        something_changed = True
    if set_identity:
        self.set_identity(G)
    return something_changed
def match(self, G: GraphView, set_identity: bool = True):
    """Materialise per-index pending transposes as TransposeParameters nodes.

    For each Transposable node other than TransposeParameters, an input edge
    ``to_idx`` gets a transpose when ``transpose_in[to_idx]`` is set.  An
    output edge ``from_idx`` gets one when ``transpose_out[from_idx]`` is set.
    Dims hints move onto the inserted node, and on quantized graphs the
    node's record is copied to it.  The node's pending transposes are cleared.

    Returns:
        True if any transpose was inserted.
    """
    # Collect first so the transposes inserted below are not revisited.
    tnodes = [n for n in G.nodes()
              if isinstance(n, Transposable) and not isinstance(n, TransposeParameters)]
    has_modified_graph = False
    for node in tnodes:
        if node.transpose_in:
            for idx, edge in enumerate(G.in_edges(node.name)):
                if edge.to_idx >= len(node.transpose_in):
                    continue
                trans = node.transpose_in[edge.to_idx]
                if trans is None:
                    continue
                has_modified_graph = True
                in_params = TransposeParameters("%s_TIN_%s" % (node.name, idx),
                                                transpose=trans)
                if node.in_dims_hint and node.in_dims_hint[edge.to_idx]:
                    in_hint = node.in_dims_hint[edge.to_idx]
                    out_hint = apply_reverse_transpose_to_hint(in_hint, trans)
                    in_params.in_dims_hint = [in_hint.copy()]
                    in_params.out_dims_hint = [out_hint.copy()]
                    node.in_dims_hint[edge.to_idx] = out_hint
                if G.quantization:
                    G.quantization.copy_to_node(node, in_params)
                G.insert_node(in_params, edge.from_node.name, edge.to_node.name,
                              from_idx=edge.from_idx, to_idx=edge.to_idx,
                              edge_class=NNEdge)
            node.transpose_in = None
        if node.transpose_out:
            # Several out edges can share one from_idx.  Read hints from a
            # snapshot so a hint rewritten for an earlier edge is not
            # reverse-transposed a second time.
            orig_out_hints = list(node.out_dims_hint) if node.out_dims_hint else None
            for idx, edge in enumerate(G.out_edges(node.name)):
                if edge.from_idx >= len(node.transpose_out):
                    continue
                trans = node.transpose_out[edge.from_idx]
                if trans is None:
                    continue
                has_modified_graph = True
                out_params = TransposeParameters("%s_TOUT_%s" % (node.name, idx),
                                                 transpose=trans)
                # skip missing hints, as the input path does
                if orig_out_hints and orig_out_hints[edge.from_idx]:
                    out_hint = orig_out_hints[edge.from_idx]
                    in_hint = apply_reverse_transpose_to_hint(out_hint, trans)
                    out_params.in_dims_hint = [in_hint.copy()]
                    out_params.out_dims_hint = [out_hint.copy()]
                    node.out_dims_hint[edge.from_idx] = in_hint
                if G.quantization:
                    G.quantization.copy_to_node(node, out_params)
                G.insert_node(out_params, edge.from_node.name, edge.to_node.name,
                              from_idx=edge.from_idx, to_idx=edge.to_idx,
                              edge_class=NNEdge)
            node.transpose_out = None
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Repeatedly remove reshapes (and transposes) that are unnecessary.

    Each round does two things:

    - It removes every ReshapeParameters with no transpose whose old and new
      shapes are equal.
    - It asks ``self.validate_reshape`` for one reshape whose downstream
      candidates can be removed.  If one is found, the candidates' consumers
      are reattached to that reshape and its shape is set to the returned
      output shape.

    Rounds repeat until nothing changes.  Quantization records of removed
    nodes are deleted.

    Returns:
        True if any round modified the graph.  The loop flag alone is always
        False on exit, so it cannot serve as the result.
    """
    has_modified_graph = False
    modified_graph = True
    while modified_graph:
        modified_graph = False
        # snapshot: nodes are removed while iterating
        for reshape in list(G.nodes(node_classes=(ReshapeParameters, ))):
            if not reshape.has_transpose and reshape.shape.shape == reshape.old_shape.shape:
                modified_graph = True
                LOG.info('removing reshape that does nothing %s', reshape.name)
                G.remove_and_reconnect(reshape, edge_class=NNEdge)
                nid = NodeId(reshape)
                if G.quantization and nid in G.quantization:
                    del G.quantization[nid]
        for reshape in list(G.nodes(node_classes=(ReshapeParameters, ))):
            res = self.validate_reshape(G, reshape)
            if not res:
                continue
            LOG.info('unnecessary reshape found after %s', reshape.name)
            modified_graph = True
            (reshape, candidates, out_shape) = res
            for candidate in candidates:
                LOG.info('removing unnecessary reshape or transpose %s', candidate.name)
                edges = G.out_edges(candidate.name)
                G.remove(candidate)
                nid = NodeId(candidate)
                if G.quantization and nid in G.quantization:
                    del G.quantization[nid]
                for edge in edges:
                    G.add_edge(NNEdge(from_node=reshape, to_node=edge.to_node,
                                      to_idx=edge.to_idx))
            reshape.shape = Dim.unnamed(out_shape)
            break
        if modified_graph:
            has_modified_graph = True
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def split_down_from(cur_g, node, res_g=None):
    """Split cur_g into 2 graphs: everything from node down, and the rest.

    ``node`` and every node reachable from it through out-edges are moved out
    of ``cur_g`` into ``res_g``, together with clones of their out-edges.
    ``res_g`` is created as an empty GraphView when not given.

    All reachable nodes and their out-edges are collected before anything is
    removed.  Removing a node from ``cur_g`` also drops its edges, so the old
    eager removal lost edges into nodes reachable by more than one path
    (A->B->D and A->C->D lost C->D).  The traversal is iterative, so deep
    graphs no longer hit the recursion limit.

    Returns:
        ``res_g``.
    """
    if res_g is None:
        res_g = GraphView()
    ordered = []
    out_edges_by_name = {}
    stack = [node]
    while stack:
        current = stack.pop()
        if current.name in out_edges_by_name:
            continue
        edges = list(cur_g.out_edges(current.name))
        out_edges_by_name[current.name] = edges
        ordered.append(current)
        # reversed so successors are visited in out-edge order (preorder)
        stack.extend(edge.to_node for edge in reversed(edges))
    for current in ordered:
        if current not in res_g.nodes():
            res_g.add_node(current)
        for edge in out_edges_by_name[current.name]:
            res_g.add_edge(edge.clone())
    for current in ordered:
        cur_g.remove(current)
    return res_g
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse each FcParameters with the activation that follows it.

    The allowed activations depend on ``kwargs['group_identity']``:
    VALID_ACTIVATIONS_POW2 for ``'pow2_match_group'``, otherwise
    VALID_ACTIVATIONS_SQ8.  Each chain of two or more nodes is wrapped in a
    LinearFusionParameters, and its quantization records are moved into the
    fusion with a combined ``QRec.copy_ktype`` record.  The chain is then
    replaced in ``G`` by the fusion.

    Returns:
        True if any fusion was created.
    """
    group_identity = kwargs.get('group_identity')
    valid_activations = (VALID_ACTIVATIONS_POW2 if group_identity == 'pow2_match_group'
                         else VALID_ACTIVATIONS_SQ8)
    fused_any = False
    fc_nodes = [params for params in G.nodes() if isinstance(params, FcParameters)]
    for fc_node in fc_nodes:
        found = self.get_node_list(G, fc_node, valid_activations)
        if found is None or len(found.order) < 2:
            continue
        LOG.info("fusing nodes %s", ",".join(member.name for member in found.order))
        fused_any = True

        subgraph = GraphView()
        tail = None
        for member in found.order:
            if tail is not None:
                subgraph.add_edge(NNEdge(from_node=tail, to_node=member))
            tail = member
        input_mapping = [[(found.linear, in_idx)] for in_idx in range(3)]
        output_mapping = [(tail, 0)]
        pnode = LinearFusionParameters(
            found.linear.name + '_fusion',
            fusion_type=found.fusion_type,
            subgraph=subgraph,
            input_mapping=input_mapping,
            output_mapping=output_mapping)

        if G.quantization:
            # TODO - stats
            qrecs = G.quantization.get_all(pnode.contained_nodes())
            if qrecs:
                fused_qrec = QRec.copy_ktype(qrecs[0], in_qs=qrecs[0].in_qs,
                                             out_qs=qrecs[-1].out_qs)
                for inner in pnode.contained_nodes():
                    G.quantization.move_to_fusion(inner, pnode)
                G.quantization[NodeId(pnode)] = fused_qrec

        incoming = G.in_edges(found.linear.name)
        outgoing = G.out_edges(tail.name)
        for member in found.order:
            G.remove(member)
        for edge in incoming:
            G.add_edge(NNEdge(edge.from_node, pnode,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
        for edge in outgoing:
            G.add_edge(NNEdge(pnode, edge.to_node,
                              from_idx=edge.from_idx, to_idx=edge.to_idx))
    if set_identity:
        self.set_identity(G)
    return fused_any