def match(self, G: GraphView, set_identity: bool = True):
    """Find and bypass the ReLU nodes reported by ``find_redundant_relus``.

    Every graph input is given a status tuple. The tuple is stored in
    ``visited_edges`` for each of the input's out edges before
    ``find_redundant_relus`` walks down from that edge's destination:

    * ``(True, max_value)`` for a ``ConstantInputParameters`` node with a
      value whose minimum is ``>= 0``. When the graph is quantized, the value
      is dequantized first.
    * ``(False, False)`` for every other input, including constants with no
      value and all non-constant inputs.

    Each node returned by ``find_redundant_relus`` is removed. Its single
    producer is then connected directly to all of its consumers.

    Args:
        G: graph to modify in place.
        set_identity: when true, ``self.set_identity(G)`` is called afterwards.

    Returns:
        True if at least one node was removed, otherwise False.
    """
    visited_edges = {}
    nodes_to_remove = []
    for node in G.inputs():
        # Default for non-constant inputs and valueless constants. Without this,
        # a non-constant input either hit an unassigned ``status`` or silently
        # inherited the previous input's status, which could wrongly mark
        # ReLUs behind a non-constant input as redundant.
        status = (False, False)
        # check if constantinput. if is then check if positive and check max value
        if isinstance(node, ConstantInputParameters) and node.value is not None:
            if G.has_quantized_parameters:
                qrec = G.quantization[NodeId(node)]
                qtype = qrec.out_qs[0]
                # Use the inner type when the quantization type wraps another
                # one (presumably the wrapper itself cannot dequantize - confirm).
                if hasattr(qtype, 'wrapped'):
                    qtype = qtype.wrapped
                val = qtype.dequantize(node.value)
            else:
                val = node.value
            if val.min() >= 0:
                status = (True, val.max())
        for edge in G.out_edges(node.name):
            visited_edges[edge] = status
            nodes_to_remove += find_redundant_relus(
                G, edge.to_node, visited_edges)

    for node in nodes_to_remove:
        # Only relus so only one in edge
        in_edge = G.in_edges(node.name)[0]
        # Snapshot the consumers before the graph is mutated below.
        for edge in list(G.out_edges(node.name)):
            G.add_edge(
                NNEdge(from_node=in_edge.from_node, from_idx=in_edge.from_idx,
                       to_node=edge.to_node, to_idx=edge.to_idx))
        G.remove(node)

    has_modified_graph = bool(nodes_to_remove)
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Bypass the ReLU nodes that ``find_redundant_relus`` reports as redundant.

    Each graph input gets a status tuple:

    * ``(True, np.max(dqvalue))`` for a ``ConstantInputParameters`` node that
      has a value and whose ``dqvalue`` minimum is ``>= 0``;
    * ``(False, False)`` otherwise.

    The status is recorded against every out edge of the input in the
    edge-status map handed to ``find_redundant_relus``. Every node that call
    returns is logged and removed. Its single producer is then re-wired to
    each of the removed node's former consumers.

    Args:
        G: graph to modify in place.
        set_identity: when true, ``self.set_identity(G)`` is called afterwards.
        **kwargs: accepted and ignored.

    Returns:
        True if any node was removed, otherwise False.
    """
    edge_status = {}
    redundant = []
    for inp in G.inputs():
        # Only a constant input with a known, non-negative value yields a
        # positive status.
        if isinstance(inp, ConstantInputParameters) and inp.value is not None:
            dq = inp.dqvalue
            status = (True, np.max(dq)) if np.min(dq) >= 0 else (False, False)
        else:
            status = (False, False)
        for out_edge in G.out_edges(inp.name):
            edge_status[out_edge] = status
            redundant.extend(
                find_redundant_relus(G, out_edge.to_node, edge_status))

    for relu in redundant:
        LOG.info("removing redundant relu %s", relu.name)
        # Only relus so only one in edge
        src = G.in_edges(relu.name)[0]
        consumers = G.out_edges(relu.name)
        G.remove(relu)
        for dst in consumers:
            G.add_edge(NNEdge(from_node=src.from_node, from_idx=src.from_idx,
                              to_node=dst.to_node, to_idx=dst.to_idx))

    if set_identity:
        self.set_identity(G)
    return bool(redundant)
def construct_subgraph(G, nodes):
    """Copy ``nodes`` of ``G`` into a new ``GraphView`` with fusion I/O nodes.

    Only edges of ``G`` whose endpoints are both in ``nodes`` are copied. For
    each distinct producer output ``(from_node, from_idx)`` of ``G`` feeding a
    subgraph input node, a ``FusionInputParameters`` node is created. That
    node is wired to every subgraph consumer of the producer output. For each
    output index of a subgraph output node, a ``FusionOutputParameters`` node
    is attached.

    NOTE(review): external in-edges are collected only for ``subg.inputs()``
    (presumably nodes with no in-edge inside ``subg``). A node that has both
    internal and external producers would keep its external inputs unmapped.
    A node with no internal edge at all is never added to ``subg``. Confirm
    that callers never pass such node sets.

    Args:
        G: source graph.
        nodes: iterable of nodes of ``G`` to include.

    Returns:
        ``(subg, inputs_map, outputs_map)``:

        * ``inputs_map`` lists the ``(from_node, from_idx)`` in ``G`` behind
          each created fusion input, in creation order.
        * ``outputs_map`` holds one list per subgraph output node. Each entry
          in that list is, per output index, a list of the
          ``(to_node, to_idx)`` consumers of that output in ``G``.
    """
    subg = GraphView()
    nodes = set(nodes)
    for node in nodes:
        for edge in G.out_edges(node.name):
            # only add internal edges
            if edge.to_node in nodes:
                subg.add_edge(NNEdge(from_node=edge.from_node, to_node=edge.to_node,
                                     from_idx=edge.from_idx, to_idx=edge.to_idx))

    # Group the external edges feeding the subgraph inputs by producer output.
    # dict.fromkeys de-duplicates exactly as the previous set() did, and it is
    # fully evaluated before any fusion-input edge is added to ``subg``.
    external_in_edges = dict.fromkeys(
        edge for node in subg.inputs() for edge in G.in_edges(node.name))
    inputs = {}
    for edge in external_in_edges:
        inputs.setdefault((edge.from_node, edge.from_idx), []).append(
            (edge.to_node, edge.to_idx))

    inputs_map = []
    for (fnode, fidx), outs in inputs.items():
        inp = FusionInputParameters(
            f'{fnode.name}_{fidx}_in', dims=fnode.out_dims[fidx])
        inputs_map.append((fnode, fidx))
        for (tnode, tidx) in outs:
            subg.add_edge(NNEdge(from_node=inp, to_node=tnode, to_idx=tidx))

    # Materialize the output list before fusion-output nodes are added.
    # Otherwise the new nodes could be picked up as outputs themselves.
    outputs = [(node, {edge.from_idx for edge in G.out_edges(node.name)})
               for node in subg.outputs()]
    outputs_map = []
    for (node, fidxes) in outputs:
        output_map = []
        outputs_map.append(output_map)
        for fidx in fidxes:
            # Must be a list. A generator expression here is one-shot, and it
            # reads ``fidx`` lazily (late binding), so once consumed every
            # entry would filter on the last output index.
            output_map.append([(edge.to_node, edge.to_idx)
                               for edge in G.out_edges(node.name)
                               if edge.from_idx == fidx])
            outp = FusionOutputParameters(
                f'{node.name}_{fidx}_out', dims=node.out_dims[fidx])
            subg.add_edge(NNEdge(from_node=node, to_node=outp, from_idx=fidx))
    return (subg, inputs_map, outputs_map)
def new_eval_copy(subg: GraphView) -> CopySet:
    """Build a ``CopySet`` pairing each output of ``subg`` with its first slice.

    ``walk_down`` is run from every input of ``subg`` using one shared
    mapping. The mapping is presumably filled with a per-node sequence of
    slices (confirm against ``walk_down``). Each output node is then paired
    with element ``[0]`` of its entry.
    """
    slices_by_node = {}
    for input_node in subg.inputs():
        walk_down(subg, input_node, slices_by_node)
    pairs = [(out_node, slices_by_node[out_node][0])
             for out_node in subg.outputs()]
    return CopySet(pairs)