def match(self, G: GraphView, set_identity: bool = True):
    """Find nodes whose output can be changed and apply the change.

    A node is a candidate when ``self.can_change_output(node)`` is truthy
    and ``self.can_change_input(G, succ)`` returns something other than
    ``None`` for every one of its successors. The candidates, mapped to the
    concatenated results of those calls, are passed through
    ``self.validate_multi_input`` and ``self.do_change`` is called on each
    node that remains.

    Args:
        G: graph to process. Nothing is done when ``G.quantization`` is
            falsy.
        set_identity: when true, ``self.set_identity(G)`` is called after
            the changes have been applied.

    Returns:
        None.
    """
    if not G.quantization:
        return

    candidates = {}
    for node in G.nodes():
        if not self.can_change_output(node):
            continue
        # G.successors yields one group of successors per output; flatten
        # them up front so every successor is known before any is queried.
        flat_successors = [
            succ
            for group in G.successors(node.name)
            for succ in group
        ]
        collected = []
        for succ in flat_successors:
            found = self.can_change_input(G, succ)
            if found is None:
                # one successor cannot change its input: drop this node
                break
            collected.extend(found)
        else:
            # only reached when no successor rejected the change
            candidates[node] = collected

    candidates = self.validate_multi_input(G, candidates)
    for node in candidates:
        # all nodes that can currently change output have one output
        self.do_change(G, node)

    if set_identity:
        self.set_identity(G)
def match(self, G: GraphView, set_identity: bool = True):
    """Create quantization records for nodes that lack a complete one.

    Every entry yielded by ``G.quantization.sorted_iterator(G)`` whose
    record is ``None`` or has a falsy ``in_qs`` or ``out_qs`` is examined.
    Where possible a ``MultQuantizationRecord`` is built from the records
    of the neighbouring nodes: input qtypes come from the ``out_qs`` of the
    predecessors and output qtypes from the ``in_qs`` of the successors.
    When only one side yields qtypes and the node has no neighbours on the
    other side, that side's qtypes are used for both inputs and outputs.

    Fused nodes, nodes no longer in the graph, nodes whose neighbours are
    not all quantized on either side, and nodes whose neighbours carry
    records that are not ``MultQuantizationRecord`` instances are skipped.

    Args:
        G: graph to process. Nothing is done when ``G.quantization`` is
            falsy.
        set_identity: when true, ``self.set_identity(G)`` is called once
            all nodes have been processed.

    Returns:
        ``None`` when ``G.quantization`` is falsy, otherwise ``False``.

    Raises:
        NotImplementedError: when no qtypes are obtained on one side of a
            node that nevertheless has neighbours on that side.
    """
    if not G.quantization:
        return

    # Build the work list before looping: G.quantization is written to
    # inside the loop.
    pending = [
        node_id
        for node_id, record in G.quantization.sorted_iterator(G)
        if record is None or not record.in_qs or not record.out_qs
    ]

    for node_id in pending:
        if node_id.fnode_name:
            LOG.warning("can't add quantization to fused node %s", node_id.fnode_name)
            continue
        if node_id.node_name not in G:
            # previous fusions may have removed nodes from the graph
            continue

        node = node_id.get_node(G)
        pred_ids = [NodeId(pred) for pred in G.predecessors(node.name)]
        succ_ids = [
            NodeId(succ)
            for group in G.successors(node.name)
            for succ in group
        ]

        # Only the truthiness of these two flags is used below.
        go_back = not succ_ids or (
            pred_ids
            and all(pred_id in G.quantization for pred_id in pred_ids)
        )
        go_forward = not pred_ids or (
            succ_ids
            and all(succ_id in G.quantization for succ_id in succ_ids)
        )
        if not go_back and not go_forward:
            LOG.warning("node %s is not connected to anything and has no quantization", node.name)
            continue

        # Output qtypes are taken from the inputs of the successors.
        out_qtypes = None
        if go_forward:
            succ_records = {G.quantization[succ_id] for succ_id in succ_ids}
            if any(not isinstance(rec, MultQuantizationRecord)
                   for rec in succ_records):
                continue
            out_edge_qtypes = [
                (edge.from_idx,
                 G.quantization[NodeId(edge.to_node)].in_qs[edge.to_idx])
                for edge in G.out_edges(node.name)
            ]
            out_qtypes = reduce_qtypes(out_edge_qtypes)

        # Input qtypes are taken from the outputs of the predecessors.
        in_qtypes = None
        if go_back:
            pred_records = {G.quantization[pred_id] for pred_id in pred_ids}
            if any(not isinstance(rec, MultQuantizationRecord)
                   for rec in pred_records):
                continue
            in_edge_qtypes = [
                (edge.to_idx,
                 G.quantization[NodeId(edge.from_node)].out_qs[edge.from_idx])
                for edge in G.in_edges(node.name)
            ]
            in_qtypes = reduce_qtypes(in_edge_qtypes)

        # Choose which qtypes feed each side of the new record.
        if not in_qtypes:
            if pred_ids:
                raise NotImplementedError("propagating qrecs not implemented")
            LOG.info("setting quantization on input node %s", node.name)
            src_in_qs = src_out_qs = out_qtypes
        elif not out_qtypes:
            if succ_ids:
                raise NotImplementedError("propagating qrecs not implemented")
            LOG.info("setting quantization on output node %s", node.name)
            src_in_qs = src_out_qs = in_qtypes
        else:
            LOG.info("setting quantization on node %s", node.name)
            src_in_qs = in_qtypes
            src_out_qs = out_qtypes

        # Each side gets its own deep copy so the new record shares no
        # qtype objects with its neighbours or between its two sides.
        G.quantization[node_id] = MultQuantizationRecord(
            in_qs=deepcopy(src_in_qs),
            out_qs=deepcopy(src_out_qs),
        )

    if set_identity:
        self.set_identity(G)
    return False