Пример #1
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Remove the nodes that find_redundant_relus reports as redundant.

        Every graph input is given a status tuple that is recorded on each of
        its out edges in ``visited_edges``:

        * ``(True, max_value)`` for a ConstantInputParameters node with a
          value whose smallest element is >= 0. When the graph has quantized
          parameters, the value is first dequantized with the node's output
          quantization type.
        * ``(False, False)`` otherwise, including a constant input whose value
          is None.

        find_redundant_relus is called on each out edge's destination with the
        shared ``visited_edges`` dict. The nodes it returns are deleted, and
        each one is bypassed by connecting its single producer directly to all
        of its consumers.

        Args:
            G: graph to modify in place.
            set_identity: when True, ``self.set_identity(G)`` is called at the end.

        Returns:
            True if at least one node was removed, otherwise False.
        """
        visited_edges = {}
        nodes_to_remove = []
        has_modified_graph = False
        for node in G.inputs():
            # Default status for any input we cannot reason about. Binding it
            # up front guarantees every input gets its own status, including
            # a constant input whose value is None.
            status = (False, False)
            # check if constantinput. if is then check if positive and check max value
            if isinstance(node, ConstantInputParameters) and node.value is not None:
                if G.has_quantized_parameters:
                    qrec = G.quantization[NodeId(node)]
                    qtype = qrec.out_qs[0]
                    # Unwrap a wrapping quantization type so dequantize is
                    # applied by the inner type.
                    if hasattr(qtype, 'wrapped'):
                        qtype = qtype.wrapped
                    val = qtype.dequantize(node.value)
                else:
                    val = node.value
                if val.min() >= 0:
                    status = (True, val.max())

            for edge in G.out_edges(node.name):
                visited_edges[edge] = status
                nodes_to_remove += find_redundant_relus(
                    G, edge.to_node, visited_edges)
        for node in nodes_to_remove:
            has_modified_graph = True
            # Only relus so only one in edge
            in_edge = G.in_edges(node.name)[0]
            # Snapshot the consumer edges, then remove the node before adding
            # the bypass edges. The replacement edges are therefore never
            # added while the old edges into the same consumer inputs still
            # exist (same ordering as the _match variant).
            out_edges = list(G.out_edges(node.name))
            G.remove(node)
            for edge in out_edges:
                G.add_edge(
                    NNEdge(from_node=in_edge.from_node,
                           from_idx=in_edge.from_idx,
                           to_node=edge.to_node,
                           to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Пример #2
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Delete the nodes that find_redundant_relus flags as redundant.

        Each graph input is classified once:

        * a ConstantInputParameters node with a non-None ``value`` whose
          ``dqvalue`` has no element below 0 yields ``(True, max(dqvalue))``;
        * every other input yields ``(False, False)``.

        That tuple is stored in ``edge_status`` for each of the input's out
        edges. find_redundant_relus is then called on each edge's destination
        with the shared ``edge_status`` dict, and whatever it returns is
        collected. Each collected node is removed, and its single producer is
        connected straight to all of its former consumers.

        Args:
            G: graph to modify in place.
            set_identity: when True, ``self.set_identity(G)`` is called at the end.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            True when at least one node was removed, False otherwise.
        """
        def classify(graph_input):
            # Only a constant input holding a concrete value can be inspected.
            if not isinstance(graph_input, ConstantInputParameters):
                return (False, False)
            if graph_input.value is None:
                return (False, False)
            dequantized = graph_input.dqvalue
            if np.min(dequantized) >= 0:
                return (True, np.max(dequantized))
            return (False, False)

        edge_status = {}
        redundant = []
        for graph_input in G.inputs():
            input_status = classify(graph_input)
            for out_edge in G.out_edges(graph_input.name):
                edge_status[out_edge] = input_status
                redundant += find_redundant_relus(
                    G, out_edge.to_node, edge_status)

        for relu in redundant:
            # Nodes removed here are ReLUs, which have a single in edge.
            LOG.info("removing redundant relu %s", relu.name)
            producer_edge = G.in_edges(relu.name)[0]
            consumer_edges = G.out_edges(relu.name)
            G.remove(relu)
            for consumer_edge in consumer_edges:
                G.add_edge(NNEdge(from_node=producer_edge.from_node,
                                  from_idx=producer_edge.from_idx,
                                  to_node=consumer_edge.to_node,
                                  to_idx=consumer_edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return bool(redundant)
Пример #3
0
def construct_subgraph(G, nodes):
    """Build a standalone GraphView from ``nodes`` and the edges between them.

    Edges of ``G`` whose source and destination are both in ``nodes`` are
    copied into a new GraphView.

    * Inputs: each distinct external producer output ``(from_node, from_idx)``
      that feeds a node with no internal in-edge (``subg.inputs()``) becomes
      one FusionInputParameters node. That fusion input is wired to every
      subgraph input it fed.
    * Outputs: each node with no internal out-edge (``subg.outputs()``) gets
      one FusionOutputParameters node per output index it drives in ``G``.

    Args:
        G: the graph to extract from.
        nodes: iterable of nodes of ``G`` to include.

    Returns:
        ``(subg, inputs_map, outputs_map)`` where ``subg`` is the new
        GraphView. ``inputs_map`` lists the external ``(from_node, from_idx)``
        pairs in the order their fusion inputs were created. ``outputs_map``
        has one entry per output node of ``subg``; each entry holds, per
        output index, a list of the ``(to_node, to_idx)`` consumers of that
        output in ``G``.

    NOTE(review): external inputs are only collected for nodes without
    internal in-edges, and external outputs only for nodes without internal
    out-edges. An interior node that also has an external producer or
    consumer gets no fusion input/output for it. Confirm callers never pass
    such node sets.
    """
    subg = GraphView()
    nodes = set(nodes)
    for node in nodes:
        for edge in G.out_edges(node.name):
            # only add internal edges
            if edge.to_node in nodes:
                subg.add_edge(NNEdge(from_node=edge.from_node, to_node=edge.to_node,
                                     from_idx=edge.from_idx, to_idx=edge.to_idx))

    # Group the external in-edges of the subgraph's input nodes by the
    # producer output they come from. A producer output that fans out to
    # several subgraph nodes thus becomes a single fusion input.
    external_in_edges = {edge for node in subg.inputs()
                         for edge in G.in_edges(node.name)}
    inputs = {}
    for edge in external_in_edges:
        inputs.setdefault((edge.from_node, edge.from_idx), []).append(
            (edge.to_node, edge.to_idx))

    inputs_map = []
    for (fnode, fidx), outs in inputs.items():
        inp = FusionInputParameters(
            f'{fnode.name}_{fidx}_in', dims=fnode.out_dims[fidx])
        inputs_map.append((fnode, fidx))
        for (tnode, tidx) in outs:
            subg.add_edge(NNEdge(from_node=inp, to_node=tnode, to_idx=tidx))

    # Computed after the fusion inputs are attached (they have out edges, so
    # they are not outputs) and before the fusion outputs are added.
    outputs = [(node, set(edge.from_idx for edge in G.out_edges(node.name)))
               for node in subg.outputs()]
    outputs_map = []
    for (node, fidxes) in outputs:
        output_map = []
        outputs_map.append(output_map)
        for fidx in fidxes:
            # Store a list, not a generator expression. A generator could be
            # consumed only once, and its ``edge.from_idx == fidx`` filter
            # would read ``fidx`` lazily, i.e. the value from the last
            # iteration of this loop.
            output_map.append([(edge.to_node, edge.to_idx)
                               for edge in G.out_edges(node.name)
                               if edge.from_idx == fidx])
            outp = FusionOutputParameters(
                f'{node.name}_{fidx}_out', dims=node.out_dims[fidx])
            subg.add_edge(NNEdge(from_node=node, to_node=outp, from_idx=fidx))
    return (subg, inputs_map, outputs_map)
Пример #4
0
def new_eval_copy(subg: GraphView) -> CopySet:
    """Pair every output node of ``subg`` with element 0 of its slice entry.

    ``walk_down`` is started from each input node of ``subg`` and fills one
    shared dict keyed by node. For each output node, element 0 of its entry
    in that dict is then taken, and the ``(node, entry[0])`` pairs are passed
    to CopySet as a list. Raises KeyError if an output node has no entry.
    """
    slices_by_node = {}
    for input_node in subg.inputs():
        walk_down(subg, input_node, slices_by_node)

    copies = []
    for output_node in subg.outputs():
        first_slice = slices_by_node[output_node][0]
        copies.append((output_node, first_slice))
    return CopySet(copies)