def match(self, G: GraphView, set_identity: bool = True):
    """Apply ``do_change`` to every node whose consumers all accept it.

    A node is a candidate when ``self.can_change_output(node)`` is true and
    ``self.can_change_input`` returns something other than ``None`` for each
    of its successors.  The candidates, mapped to the concatenation of the
    per-successor results, are passed through ``self.validate_multi_input``
    and ``self.do_change`` is then called for every node that remains.

    Args:
        G: graph to process; nothing is done when ``G.quantization`` is
            falsy.
        set_identity: when true, ``self.set_identity(G)`` is called after
            the changes have been applied.
    """
    if not G.quantization:
        return

    candidates = {}
    for node in G.nodes():
        if not self.can_change_output(node):
            continue

        # Flatten the grouped successors up front so the graph is queried
        # before any successor is inspected.
        consumers = []
        for group in G.successors(node.name):
            consumers.extend(group)

        collected = []
        for consumer in consumers:
            found = self.can_change_input(G, consumer)
            if found is None:
                # One consumer refuses the change: drop this node entirely.
                break
            collected.extend(found)
        else:
            # Reached only when no consumer refused (including the case of
            # a node with no consumers at all).
            candidates[node] = collected

    candidates = self.validate_multi_input(G, candidates)
    # all nodes that can currently change output have one output
    for node in candidates:
        self.do_change(G, node)

    if set_identity:
        self.set_identity(G)
Пример #2
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Create quantization records for nodes that lack a complete one.

        Walks ``G.quantization.sorted_iterator(G)`` and, for every entry whose
        record is ``None`` or is missing ``in_qs``/``out_qs``, builds a
        ``MultQuantizationRecord`` from its neighbours: the ``in_qs`` of the
        consumers supply the output side and the ``out_qs`` of the producers
        supply the input side.  A node without predecessors or without
        successors reuses the one available side for both.

        Args:
            G: graph whose ``quantization`` mapping is updated in place;
                nothing is done when ``G.quantization`` is falsy.
            set_identity: when true, ``self.set_identity(G)`` is called once
                the pass has finished.

        Returns:
            Always ``False``.

        Raises:
            NotImplementedError: if no qtypes could be derived for a side of
                a node that does have neighbours on that side.
        """
        if not G.quantization:
            # Same value as the normal exit below, so the result type does
            # not depend on which path was taken.
            return False

        # Materialise the work list first: G.quantization is written to
        # inside the loop.
        incomplete = [
            nid for nid, qrec in G.quantization.sorted_iterator(G)
            if qrec is None or not (qrec.in_qs and qrec.out_qs)
        ]
        for nid in incomplete:
            if nid.fnode_name:
                LOG.warning("can't add quantization to fused node %s", nid.fnode_name)
                continue
            if nid.node_name not in G:
                # previous fusions may have removed nodes from the graph
                continue

            node = nid.get_node(G)
            predecessors = [NodeId(pred) for pred in G.predecessors(node.name)]
            successors = [NodeId(succ) for succs in G.successors(node.name) for succ in succs]
            # go_back: take the input side from the producers. Chosen when
            # there are no successors, or when every predecessor has an entry
            # in G.quantization.
            go_back = not successors or (predecessors and all(pred in G.quantization for pred in predecessors))
            # go_forward: take the output side from the consumers. Chosen when
            # there are no predecessors, or when every successor has an entry
            # in G.quantization.
            go_forward = not predecessors or (successors and all(succ in G.quantization for succ in successors))

            if not (go_back or go_forward):
                LOG.warning("node %s is not connected to anything and has no quantization", node.name)
                continue

            if go_forward:
                out_qrecs = {G.quantization[succ_id] for succ_id in successors}
                if not all(isinstance(out_qrec, MultQuantizationRecord) for out_qrec in out_qrecs):
                    continue
                # One (output index, consumer input qtype) pair per outgoing
                # edge; reduce_qtypes is defined elsewhere and presumably
                # merges them per index - TODO confirm.
                out_qtypes = reduce_qtypes([(edge.from_idx, G.quantization[NodeId(edge.to_node)].in_qs[edge.to_idx])
                                            for edge in G.out_edges(node.name)])
            else:
                out_qtypes = None
            if go_back:
                in_qrecs = {G.quantization[pred_id] for pred_id in predecessors}
                if not all(isinstance(in_qrec, MultQuantizationRecord) for in_qrec in in_qrecs):
                    continue
                # One (input index, producer output qtype) pair per incoming
                # edge.
                in_qtypes = reduce_qtypes([(edge.to_idx, G.quantization[NodeId(edge.from_node)].out_qs[edge.from_idx])
                                           for edge in G.in_edges(node.name)])
            else:
                in_qtypes = None

            if not in_qtypes:
                if not predecessors:
                    # No producers: the output qtypes are used for both sides.
                    LOG.info("setting quantization on input node %s", node.name)
                    qrec = MultQuantizationRecord(in_qs=deepcopy(out_qtypes), out_qs=deepcopy(out_qtypes))
                else:
                    raise NotImplementedError("propagating qrecs not implemented")
            elif not out_qtypes:
                if not successors:
                    # No consumers: the input qtypes are used for both sides.
                    LOG.info("setting quantization on output node %s", node.name)
                    qrec = MultQuantizationRecord(in_qs=deepcopy(in_qtypes), out_qs=deepcopy(in_qtypes))
                else:
                    raise NotImplementedError("propagating qrecs not implemented")
            else:
                LOG.info("setting quantization on node %s", node.name)
                qrec = MultQuantizationRecord(in_qs=deepcopy(in_qtypes), out_qs=deepcopy(out_qtypes))

            G.quantization[nid] = qrec

        if set_identity:
            self.set_identity(G)

        return False