Ejemplo n.º 1
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Insert a CopyParameters node in front of qualifying split nodes.

        A SplitParameters node qualifies when ``search_up_for_input(G, node)``
        is truthy (presumably: a graph input is reachable upstream of the
        split - TODO confirm against that helper). For each qualifying split,
        a new CopyParameters node is inserted on the split's first input
        edge. If the graph is quantized, the quantization of the split's
        input 0 is copied onto the new node.

        Args:
            G: graph to modify in place.
            set_identity: when True, ``self.set_identity(G)`` is called
                before returning.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            True if at least one copy node was inserted, otherwise False.
        """
        # Collect every qualifying split up front so the graph is not
        # mutated while it is being scanned.
        splits_fed_by_input = list(filter(
            lambda split: search_up_for_input(G, split),
            G.nodes(node_classes=SplitParameters)))

        for split_node in splits_fed_by_input:
            LOG.info("Insert copy on split input %s", split_node.name)
            copy_node = CopyParameters(G.unique_name(f'{split_node.name}_copy'))
            split_input_edge = G.in_edges(split_node.name)[0]
            G.insert_node_at_edge(copy_node, split_input_edge)
            if G.quantization:
                G.quantization.copy_qrec(split_node, 'in', 0, copy_node)

        if set_identity:
            self.set_identity(G)
        # The graph was modified exactly when there was something to copy.
        return bool(splits_fed_by_input)
Ejemplo n.º 2
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Insert copy nodes on split/concat inputs fed by (constant) inputs.

        For every SplitParameters or ConcatParameters node, each indexed
        input edge is traced back with ``find_real_in_edge``. If the real
        producer is an InputParameters or ConstantInputParameters node, a
        CopyParameters node is inserted on that edge. When the graph is
        quantized, the copy node gets the quantization of the specific
        input it was inserted on.

        Args:
            G: graph to modify in place.
            set_identity: when True, ``self.set_identity(G)`` is called
                before returning.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            True if at least one copy node was inserted, otherwise False.
        """
        # Gather all edges first and mutate afterwards, so the graph is not
        # changed while its nodes/edges are being walked.
        need_a_copy_edges = []
        for node in G.nodes(node_classes=(SplitParameters, ConcatParameters)):
            for idx, edge in enumerate(G.indexed_in_edges(node.name)):
                real_from_node, _ = find_real_in_edge(G, edge)
                if isinstance(real_from_node, (InputParameters, ConstantInputParameters)):
                    need_a_copy_edges.append((edge, idx))

        for edge, idx in need_a_copy_edges:
            to_node = edge.to_node
            LOG.info(
                "Insert copy on split input %s", to_node.name)
            cnode = CopyParameters(G.unique_name(f'{to_node.name}_copy'))
            G.insert_node_at_edge(cnode, edge)
            if G.quantization:
                # Copy the qrec of the input this copy sits on. Previously
                # input 0 was always used, which is wrong for concat inputs
                # with idx > 0.
                G.quantization.copy_qrec(to_node, 'in', idx, cnode)

        if set_identity:
            self.set_identity(G)
        # The graph was modified exactly when some edge needed a copy.
        return bool(need_a_copy_edges)
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Move an operation duplicated on a split's outputs in front of it.

        For each SplitParameters node, ``self.moveable_same_operation_edges``
        is asked for its candidate out edges. If it returns any:

        * the split must have exactly one input edge (asserted);
        * the candidates are sorted by consumer name and the first consumer
          becomes ``keep_node``;
        * every consumer on the split's out edges is removed from the graph,
          and that consumer's own consumers are reconnected directly to the
          split output that fed it;
        * quantization records of the removed consumers other than
          ``keep_node`` are deleted;
        * ``keep_node`` is re-inserted on the split's input edge;
        * if the graph is quantized, the whole graph is re-quantized.

        Args:
            G: graph to modify in place.
            set_identity: when True, ``self.set_identity(G)`` is called
                before returning.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            True if at least one split node was rewritten, otherwise False.
        """
        has_modified_graph = False
        # Snapshot the split nodes: the loop body adds and removes nodes, so
        # the graph must not be iterated live while it is being mutated.
        for node in list(G.nodes(node_classes=SplitParameters)):
            same_op_edges = self.moveable_same_operation_edges(G, node)
            if not same_op_edges:
                continue
            has_modified_graph = True
            in_edges = G.in_edges(node.name)
            # keep_node is re-inserted on this single input edge below
            assert len(in_edges) == 1
            # sort by name to ensure that operation is repeatable
            same_op_edges.sort(key=lambda x: x.to_node.name)
            keep_node = same_op_edges[0].to_node
            LOG.info('split node %s has duplicate operations on its out edges',
                     node.name)
            LOG.info('moving %s before split node %s', keep_node.name,
                     node.name)
            # Snapshot the out edges: new edges from `node` are added inside
            # this loop and must not be visited (their consumers would be
            # removed).
            for edge in list(G.out_edges(node.name)):
                # Read the consumer's outputs before removing it; presumably
                # G.remove also drops the node's edges.
                node_out_edges = G.out_edges(edge.to_node.name)
                G.remove(edge.to_node)
                if edge.to_node != keep_node:
                    LOG.info('deleting duplicate node %s', edge.to_node.name)
                    if G.quantization:
                        nid = NodeId(edge.to_node)
                        if nid in G.quantization:
                            del G.quantization[nid]
                # Bypass the removed node: wire its consumers straight to the
                # split output that fed it.
                # NOTE(review): the removed node's own from_idx values are not
                # kept, so this assumes it has a single output.
                for out_edge in node_out_edges:
                    G.add_edge(
                        NNEdge(from_node=node,
                               from_idx=edge.from_idx,
                               to_node=out_edge.to_node,
                               to_idx=out_edge.to_idx))
            G.insert_node_at_edge(keep_node, in_edges[0], edge_class=NNEdge)
            if G.quantization:
                # nodes were moved/deleted, so rebuild quantization for the
                # whole graph
                quantizer = NewQuantizer.from_quantized_graph(G)
                quantizer.quantize()

        if set_identity:
            self.set_identity(G)

        return has_modified_graph