Exemplo n.º 1
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Fuse every node set found by ``find_sets`` into an expression node.

        For each node set yielded by ``self.find_sets(G)``, the internal edges
        are copied into a fragment ``GraphView``. The fragment is wrapped in an
        ``ExpressionFusionParameters`` node named ``expr_<n>``, and that node
        replaces the fragment in ``G``.

        Args:
            G: graph to modify in place.
            set_identity: when true, ``self.set_identity(G)`` is called after
                all node sets have been processed.

        Returns:
            True if at least one node set was fused, otherwise False.
        """
        has_modified_graph = False
        for node_set in self.find_sets(G):
            has_modified_graph = True
            in_edges, out_edges, internal_edges = group_edges(G, node_set)
            frag = GraphView()
            for edge in internal_edges:
                frag.add_edge(edge)
            # in_edges / out_edges are keyed by (from_node, from_idx). The
            # mappings and dims below all iterate the same dict, so entry i
            # of each list refers to the same external connection.
            in_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_group]
                          for edge_group in in_edges.values()]
            in_dims = [
                from_node.out_dims[from_idx]
                for from_node, from_idx in in_edges
            ]
            out_dims = [
                from_node.out_dims[from_idx]
                for from_node, from_idx in out_edges
            ]
            out_mapping = list(out_edges.keys())
            # producers of this set's inputs that are constants
            constant_inputs = [
                node_edge_idx[0] for node_edge_idx in in_edges
                if isinstance(node_edge_idx[0], ConstantInputParameters)
            ]
            LOG.info('matched expression - creating expression %s',
                     self._expr_num)
            expr = ExpressionFusionParameters(f"expr_{self._expr_num}",
                                              subgraph=frag,
                                              input_mapping=in_mapping,
                                              output_mapping=out_mapping,
                                              in_dims=in_dims,
                                              out_dims=out_dims,
                                              constant_inputs=constant_inputs)
            in_edge_mapping = list(in_edges.keys())
            out_edge_mapping = [[(edge.to_node, edge.to_idx)
                                 for edge in edge_set]
                                for edge_set in out_edges.values()]

            G.replace_fragment(
                frag,
                expr,
                # set().union(...) rather than set.union(...): the unbound
                # form raises TypeError when called with no sets, i.e. when
                # the node set has no external inputs or outputs.
                frag_in_edges=list(set().union(*in_edges.values())),
                frag_out_edges=list(set().union(*out_edges.values())),
                edge_in_mapping=in_edge_mapping,
                edge_out_mapping=out_edge_mapping)
            self._expr_num += 1

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Exemplo n.º 2
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Fuse connected groups of fusable nodes into expression nodes.

        Candidates are nodes that are instances of ``FUSE_NODES``, plus
        ``ConstantInputParameters`` whose first output has size 1. Candidates
        are grouped with ``group_nodes``, and groups made only of constants
        are dropped. Each remaining group is wrapped in an
        ``ExpressionFusionParameters`` node that replaces it in ``G``.

        Args:
            G: graph to modify in place.
            set_identity: when true, ``self.set_identity(G)`` is called after
                all groups have been processed.

        Returns:
            True if at least one group was fused, otherwise False.
        """
        has_modified_graph = False
        # collect connected node sets
        node_sets = group_nodes(G, [
            node for node in G.nodes() if isinstance(node, FUSE_NODES) or (
                isinstance(node, ConstantInputParameters)
                and node.out_dims[0].size() == 1)
        ])
        # remove sets that are only ConstantInputs
        node_sets = [
            node_set for node_set in node_sets if not all(
                isinstance(node, ConstantInputParameters) for node in node_set)
        ]
        for node_set in node_sets:
            has_modified_graph = True
            in_edges, out_edges, internal_edges = group_edges(G, node_set)
            frag = GraphView()
            for edge in internal_edges:
                frag.add_edge(edge)
            in_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_group]
                          for edge_group in in_edges.values()]
            # Keep this in exactly the same order as in_mapping (both iterate
            # in_edges). Presumably replace_fragment pairs edge_in_mapping[i]
            # with expression input i, so it must not be reordered on its own.
            in_edge_mapping = list(in_edges.keys())
            out_mapping = list(out_edges.keys())
            # one flag per expression input: True if its producer is a constant
            constant_inputs = [
                isinstance(node_edge_idx[0], ConstantInputParameters)
                for node_edge_idx in in_edges
            ]
            expr = ExpressionFusionParameters("expr_%s" % self._expr_num,
                                              subgraph=frag,
                                              input_mapping=in_mapping,
                                              output_mapping=out_mapping,
                                              constant_inputs=constant_inputs)
            G.replace_fragment(
                frag,
                expr,
                # set().union tolerates an empty argument list (no external
                # inputs/outputs); set.union() would raise TypeError
                frag_in_edges=list(set().union(*in_edges.values())),
                frag_out_edges=list(set().union(*out_edges.values())),
                edge_in_mapping=in_edge_mapping,
                edge_out_mapping=[[(edge.to_node, edge.to_idx)
                                   for edge in edge_set]
                                  for edge_set in out_edges.values()])
            # advance the counter so every fused group gets a distinct name
            self._expr_num += 1

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Exemplo n.º 3
0
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        """Apply ``replace_function`` to matched subgraphs until none change.

        Each subgraph yielded by ``self.match_function(G)`` is passed to
        ``self.replace_function(G, subgraph)``, which must return a triple
        ``(replacement, edge_in_mapping, edge_out_mapping)``:

        * ``replacement is None``: the subgraph is removed from ``G``;
        * ``replacement`` is a ``Node``: the subgraph is replaced by it;
        * anything else: ``TypeError`` is raised.

        A ``DontReplaceError`` from that work skips the candidate. After each
        successful change the iteration over ``match_function`` is abandoned
        and a fresh one is started. The loop ends when a full pass makes no
        change.

        Args:
            G: graph to modify in place.
            set_identity: when true, ``self.set_identity(G)`` is called at the end.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            True if any subgraph was removed or replaced.
        """
        modified = False
        keep_going = True
        while keep_going:
            keep_going = False
            for candidate in self.match_function(G):
                # Snapshot the boundary edges before calling replace_function,
                # which may alter the candidate subgraph.
                boundary_in = [edge
                               for node in candidate.inputs()
                               for edge in G.in_edges(node.name)]
                boundary_out = [edge
                                for node in candidate.outputs()
                                for edge in G.out_edges(node.name)]
                try:
                    new_node, edge_in_map, edge_out_map = \
                        self.replace_function(G, candidate)
                    if new_node is None:
                        G.remove_fragment(candidate)
                    elif isinstance(new_node, Node):
                        G.replace_fragment(candidate,
                                           new_node,
                                           frag_in_edges=boundary_in,
                                           frag_out_edges=boundary_out,
                                           edge_in_mapping=edge_in_map,
                                           edge_out_mapping=edge_out_map)
                    else:
                        raise TypeError(
                            "unexcepted return value from replace_function")
                    modified = True
                except DontReplaceError:
                    continue
                # the graph changed - restart matching from scratch
                keep_going = True
                break

        if set_identity:
            self.set_identity(G)

        return modified
Exemplo n.º 4
0
 def match(self, G: GraphView, set_identity: bool = True):
     """Apply ``replace_function`` to matched subgraphs until none remain.

     Each subgraph yielded by ``self.match_function(G)`` is passed to
     ``self.replace_function(G, subgraph)``. A falsy result removes the
     subgraph, a ``Node`` result replaces it, and any other result raises
     ``TypeError``. After every change the match iteration is restarted. The
     loop ends when a pass yields no subgraph.

     Args:
         G: graph to modify in place.
         set_identity: when true, ``self.set_identity(G)`` is called at the end.

     Returns:
         True if at least one subgraph was removed or replaced, else False.
     """
     has_modified_graph = False
     replaced = True
     while replaced:
         replaced = False
         for subgraph in self.match_function(G):
             replacement = self.replace_function(G, subgraph)
             if not replacement:
                 G.remove_fragment(subgraph)
             elif isinstance(replacement, Node):
                 G.replace_fragment(subgraph, replacement)
             else:
                 raise TypeError(
                     "unexcepted return value from replace_function")
             has_modified_graph = True
             # the graph changed - start a fresh match pass
             replaced = True
             break
     if set_identity:
         self.set_identity(G)
     return has_modified_graph
Exemplo n.º 5
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse each node set from ``find_sets`` into an expression node.

        For every node set, a fragment ``GraphView`` containing its nodes and
        internal edges is wrapped in an ``ExpressionFusionParameters`` node.
        The node gets a unique name derived from ``expr_<n>`` and replaces the
        fragment in ``G``.

        When ``G.quantization`` is set, three further steps run:

        * a ``QRec`` for the new expression is built from the quantization
          records of the nodes it replaced, with per-symbol min/max stats
          recorded in the current ``Symbol`` control;
        * ``QuantizeParameters`` nodes fed directly by the expression are
          removed, together with their quantization records;
        * all new expressions are quantized with ``UnifiedQuantizer``.

        Args:
            G: graph to modify in place.
            set_identity: when true, ``self.set_identity(G)`` is called at the end.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            True if at least one node set was fused, otherwise False.
        """
        has_modified_graph = False
        to_quantize = []
        node_sets = self.find_sets(G)
        for node_set in node_sets:
            # fresh stats collector for this expression; read back below via
            # Symbol.CURRENT_CONTROL.stats
            Symbol.set_default_control(SymbolStats())
            has_modified_graph = True
            in_edges, out_edges, internal_edges = group_edges(G, node_set)
            frag = GraphView()
            # add nodes explicitly: a node with no internal edges would not
            # be reached by the add_edge loop below
            for node in node_set:
                frag.add_node(node)
            for edge in internal_edges:
                frag.add_edge(edge)
            # in_edges / out_edges are keyed by (from_node, from_idx). All
            # lists below iterate the same dict, so entry i of each refers to
            # the same external connection.
            in_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_group]
                          for edge_group in in_edges.values()]
            in_dims = [
                from_node.out_dims[from_idx]
                for from_node, from_idx in in_edges
            ]
            out_dims = [
                from_node.out_dims[from_idx]
                for from_node, from_idx in out_edges
            ]
            out_mapping = list(out_edges.keys())
            constant_inputs = [
                node_edge_idx[0] for node_edge_idx in in_edges
                if isinstance(node_edge_idx[0], ConstantInputParameters)
            ]
            LOG.debug(
                "inputs coming from: %s",
                ",".join(f"{from_node!r}:{from_idx}"
                         for from_node, from_idx in in_edges))
            LOG.info("fusing nodes: %s into expr_%s",
                     ",".join(repr(node) for node in node_set),
                     self._expr_num)
            expr = ExpressionFusionParameters(
                G.unique_name(f"expr_{self._expr_num}"),
                subgraph=frag,
                qrecs=G.quantization,
                input_mapping=in_mapping,
                output_mapping=out_mapping,
                in_dims=in_dims,
                out_dims=out_dims,
                constant_inputs=constant_inputs)
            in_edge_mapping = list(in_edges.keys())
            out_edge_mapping = [[(edge.to_node, edge.to_idx)
                                 for edge in edge_set]
                                for edge_set in out_edges.values()]
            G.replace_fragment(
                frag,
                expr,
                # set().union tolerates an empty argument list (no external
                # inputs/outputs); set.union() would raise TypeError
                frag_in_edges=list(set().union(*in_edges.values())),
                frag_out_edges=list(set().union(*out_edges.values())),
                edge_in_mapping=in_edge_mapping,
                edge_out_mapping=out_edge_mapping,
                edge_class=NNEdge)
            if G.quantization:
                qrecs = G.quantization
                # input qtype = the qtype of the first internal consumer of
                # each external input
                in_qs = [
                    qrecs[NodeId(in_map[0][0])].in_qs[in_map[0][1]]
                    for in_map in in_mapping
                ]
                out_qs = [
                    qrecs[NodeId(node)].out_qs[idx]
                    for node, idx in out_mapping
                ]
                stats = Symbol.CURRENT_CONTROL.stats
                func_col = expr.func_col
                for idx, qtype in enumerate(in_qs):
                    symbol = func_col.variables[func_col.input_names[idx]]
                    stats[symbol.name] = {
                        'min': qtype.min_val,
                        'max': qtype.max_val
                    }
                for idx, qtype in enumerate(out_qs):
                    symbol = func_col.variables[func_col.output_names[idx]]
                    stats[symbol.name] = {
                        'min': qtype.min_val,
                        'max': qtype.max_val
                    }
                G.quantization[NodeId(expr)] = QRec(in_qs=in_qs,
                                                    out_qs=out_qs,
                                                    expression=stats,
                                                    ktype='scaled')
                # delete any quantize parameters on outputs to allow the
                # quantizer to fuse them into the expression. Iterate over a
                # snapshot, since remove_and_reconnect presumably rewires the
                # expression's out edges while we loop.
                for expr_out_edge in list(G.out_edges(expr.name)):
                    consumer = expr_out_edge.to_node
                    if isinstance(consumer, QuantizeParameters):
                        G.remove_and_reconnect(consumer)
                        if NodeId(consumer) in G.quantization:
                            del G.quantization[NodeId(consumer)]
                to_quantize.append(expr)

            self._expr_num += 1

        if to_quantize:
            quantizer = UnifiedQuantizer.from_quantized_graph(G)
            G.quantization = quantizer.quantize(G, start_nodes=to_quantize)

        if set_identity:
            self.set_identity(G)

        return has_modified_graph