Exemplo n.º 1
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Drop reshapes made redundant around an activation chain.

        Matches the pattern::

            Reshape -> (Activation | CustomMatcher)+ -> (Reshape | Fc)

        For each match whose leading reshape returns a non-empty
        ``exp_red_pattern()``, the following happens:

        * The leading reshape's ``shape`` is pushed through the intermediate
          nodes with ``propagate_shape``.
        * The leading reshape is removed (neighbours reconnected with NNEdge;
          its quantization is dropped through ``self.remove_quantization``).
        * If the trailing node is a reshape whose ``shape.shape`` equals the
          leading reshape's ``old_shape.shape``, it is removed the same way.
          Otherwise, a trailing reshape gets its ``old_shape`` set to the
          leading reshape's ``shape``.

        Unlike most passes, ``set_identity`` is applied before matching here.

        Args:
            G: graph rewritten in place.
            set_identity: when True, call ``self.set_identity(G)`` first.

        Returns:
            True if at least one match was rewritten.
        """
        if set_identity:
            self.set_identity(G)
        pattern = Sequence(
            ReshapeParameters,
            OneOrMoreOf(AnyOf(ActivationParameters, CustomMatcher())),
            AnyOf(ReshapeParameters, FcParameters))
        changed = False
        for matched in pattern.find(G):
            head, tail = matched[0], matched[-1]
            # NOTE(review): presumably describes how the leading reshape
            # expands/reduces dimensions; an empty result means "not handled".
            reduction = head.exp_red_pattern()
            if not reduction:
                continue
            changed = True
            tail_is_reshape = isinstance(tail, ReshapeParameters)
            # Decided before propagate_shape runs so the shapes compared are
            # the ones seen at match time.
            drop_tail = (tail_is_reshape
                         and head.old_shape.shape == tail.shape.shape)
            propagate_shape(G, head.shape, reduction, matched[1:-1])
            victims = [head, tail] if drop_tail else [head]
            for victim in victims:
                LOG.info('removing unnecessary reshape %s', victim.name)
                G.remove_and_reconnect(victim, edge_class=NNEdge)
                self.remove_quantization(G, victim)
            if tail_is_reshape and not drop_tail:
                # The trailing reshape now receives the leading reshape's
                # output shape directly.
                tail.old_shape = head.shape

        return changed
Exemplo n.º 2
0
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        """Remove or simplify strided slices that select their whole input.

        A StridedSliceParameters node whose ``slice_shape`` equals its input
        shape does not slice anything:

        * if ``out_shape`` also equals ``slice_shape`` the node is removed and
          its neighbours reconnected with NNEdge;
        * otherwise it is replaced by a ReshapeParameters going from
          ``slice_shape`` to ``out_shape``.

        Any quantization record of the slice is deleted, or moved onto the
        replacement reshape.

        Args:
            G: graph rewritten in place.
            set_identity: when True, call ``self.set_identity(G)`` afterwards.

        Returns:
            True if at least one strided slice was removed or replaced.
        """
        has_modified_graph = False
        # Snapshot the node list: nodes are removed/replaced inside the loop.
        for node in list(G.nodes(node_classes=StridedSliceParameters)):
            if node.slice_shape != tuple(node.in_dims[0].shape):
                continue
            has_modified_graph = True
            nid = NodeId(node)
            if node.slice_shape == node.out_shape:
                LOG.info('removing strided slice %s that does nothing',
                         node.name)
                G.remove_and_reconnect(node, edge_class=NNEdge)
                if G.quantization and nid in G.quantization:
                    del G.quantization[nid]
                continue

            reshape = ReshapeParameters(
                G.unique_name(f'{node.name}_reshape'),
                old_shape=node.slice_shape,
                shape=node.out_shape)
            LOG.info('replacing strided slice %s with reshape %s',
                     node.name, reshape.name)
            G.replace_node(node, reshape)
            if G.quantization and nid in G.quantization:
                # Carry the slice's qrec over to the reshape replacing it.
                G.quantization[NodeId(reshape)] = G.quantization[nid]
                del G.quantization[nid]

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Remove QuantizeParameters nodes by folding them into neighbours.

        For each quantize node:

        * from a floating type (np.floating or bfloat16) to a floating type:
          left in place with a warning;
        * from a floating type to a non-floating type: try to push
          ``to_qtype`` upwards with ``self.propagate_up``;
        * otherwise: try to push ``from_qtype`` downwards with
          ``self.propagate_down``.

        When propagation succeeds, the node is removed. Its neighbours are
        reconnected with NNEdge, and any quantization record it has is
        deleted. When propagation fails, a warning is logged and the node is
        kept.

        Args:
            G: graph rewritten in place.
            set_identity: when True, call ``self.set_identity(G)`` afterwards.

        Returns:
            True if at least one quantize node was removed.
        """
        float_types = (np.floating, bfloat16)
        modified_graph = False
        for node in G.nodes(node_classes=QuantizeParameters):
            if issubclass(node.from_qtype.dtype, float_types):
                if issubclass(node.to_qtype.dtype, float_types):
                    LOG.warning(
                        'node %s quantizes from floating type to floating type and cannot directly be removed',
                        node.name)
                    continue
                removable = self.propagate_up(G, node, node.to_qtype)
            else:
                removable = self.propagate_down(G, node, node.from_qtype)

            if not removable:
                LOG.warning('unable to remove quantize node %s', node.name)
                continue

            modified_graph = True
            G.remove_and_reconnect(node, edge_class=NNEdge)
            nid = NodeId(node)
            # Guarded delete: a quantize node is not guaranteed to have a
            # qrec, and an unguarded ``del`` would raise KeyError.
            if G.quantization and nid in G.quantization:
                del G.quantization[nid]

        if set_identity:
            self.set_identity(G)

        return modified_graph
Exemplo n.º 4
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Repeatedly remove reshapes that do nothing or are redundant.

        Each iteration runs two passes until neither changes the graph:

        1. Remove every ReshapeParameters without a transpose whose
           ``shape.shape`` equals its ``old_shape.shape``.
        2. Ask ``self.validate_reshape`` about each reshape. On the first
           truthy result ``(keep, candidates, out_shape)``:

           * every candidate is removed, and its consumers are rewired to
             ``keep``;
           * ``keep.shape`` is set to ``out_shape``;
           * the pass restarts, because the graph has changed.

        Quantization records of removed nodes are deleted.

        Args:
            G: graph rewritten in place.
            set_identity: when True, call ``self.set_identity(G)`` afterwards.

        Returns:
            True if the graph was modified in any iteration. The original
            returned the loop flag, which is always False once the loop ends.
        """
        has_modified_graph = False
        modified_graph = True
        while modified_graph:
            modified_graph = False

            for reshape in G.nodes(node_classes=(ReshapeParameters, )):
                if reshape.has_transpose or reshape.shape.shape != reshape.old_shape.shape:
                    continue
                modified_graph = True
                LOG.info('removing reshape that does nothing %s',
                         reshape.name)
                G.remove_and_reconnect(reshape, edge_class=NNEdge)
                nid = NodeId(reshape)
                if G.quantization and nid in G.quantization:
                    del G.quantization[nid]

            for reshape in G.nodes(node_classes=(ReshapeParameters, )):
                res = self.validate_reshape(G, reshape)
                if not res:
                    continue
                LOG.info('unnecessary reshape found after %s',
                         reshape.name)
                modified_graph = True
                keep, candidates, out_shape = res
                for candidate in candidates:
                    LOG.info(
                        'removing unnecessary reshape or transpose %s',
                        candidate.name)
                    # Capture the consumers before the candidate disappears.
                    edges = G.out_edges(candidate.name)
                    G.remove(candidate)
                    nid = NodeId(candidate)
                    if G.quantization and nid in G.quantization:
                        del G.quantization[nid]
                    for edge in edges:
                        G.add_edge(
                            NNEdge(from_node=keep,
                                   to_node=edge.to_node,
                                   to_idx=edge.to_idx))
                keep.shape = Dim.unnamed(out_shape)
                # The graph changed under this iteration; start over.
                break

            has_modified_graph = has_modified_graph or modified_graph

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Exemplo n.º 5
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Remove CopyParameters nodes that are not needed.

        Only copies with exactly one output edge (and at least one input
        edge) are considered. A copy is kept when both conditions hold:

        * ``search_down`` from its output edge reaches an Input/Output/
          ConstantInput/Split/Concat node;
        * ``search_up`` from its first input edge reaches one of those nodes
          too.

        Both searches may pass through reshapes, no-op nodes and transposes
        whose ``does_nothing`` is true. All other copies are removed, and
        their neighbours are reconnected with NNEdge.

        NOTE(review): assumes search_down/search_up return truthy when such a
        boundary node is reached — confirm against their implementation.

        Args:
            G: graph rewritten in place.
            set_identity: when True, call ``self.set_identity(G)`` afterwards.

        Returns:
            True if at least one copy node was removed.
        """
        boundary = (InputParameters, OutputParameters,
                    ConstantInputParameters, SplitParameters,
                    ConcatParameters)
        passable = (ReshapeParameters, NoOPParameters)

        def noop_transpose(_, node):
            # Transposes that leave the data unchanged can be searched through.
            return isinstance(node, TransposeParameters) and node.does_nothing

        nodes_to_remove = []
        for node in G.nodes(node_classes=CopyParameters):
            out_edges = G.out_edges(node)
            in_edges = G.in_edges(node)
            # Only single-consumer copies with an input are handled; the
            # original indexed out_edges[0] even when there were no consumers.
            if len(out_edges) != 1 or not in_edges:
                continue
            copy_needed = search_down(
                G, out_edges[0], boundary,
                can_pass=passable, can_pass_fn=noop_transpose,
                follow_multi=True) and search_up(
                    G, in_edges[0], boundary,
                    can_pass=passable, can_pass_fn=noop_transpose,
                    follow_multi=True)
            if not copy_needed:
                nodes_to_remove.append(node)

        has_modified_graph = False
        for node in nodes_to_remove:
            LOG.info("remove redundant copy %s", node.name)
            has_modified_graph = True
            G.remove_and_reconnect(node, edge_class=NNEdge)
            nid = NodeId(node)
            if G.quantization and nid in G.quantization:
                del G.quantization[nid]

        if set_identity:
            self.set_identity(G)
        return has_modified_graph
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Remove identity quantize nodes and merge chains of quantize nodes.

        For each QuantizeParameters node:

        * if the graph is quantized and the node has no qrec, it is skipped;
        * if its qrec's input qtype equals its output qtype, the node is
          removed;
        * otherwise, while ``self.get_single_quantize_edge`` returns a
          following quantize node, that node is removed and its
          ``to_qtype`` (and qrec ``out_qs``) are taken over by the current
          node.

        Args:
            G: graph rewritten in place.
            set_identity: when True, call ``self.set_identity(G)`` afterwards.

        Returns:
            True if any quantize node was removed or merged.
        """
        modified_graph = False
        quantize_nodes = G.nodes(node_classes=QuantizeParameters)
        while quantize_nodes:
            node = quantize_nodes.pop(0)
            if G.quantization:
                qrec = G.quantization.get(NodeId(node))
                if not qrec:
                    continue
                # NOTE(review): the deepcopy before comparing looks
                # intentional (possibly to avoid side effects of __eq__);
                # kept as is.
                if deepcopy(qrec.in_qs[0]) == qrec.out_qs[0]:
                    modified_graph = True
                    LOG.info('removing quantize node %s from %s to %s',
                             node.name, qrec.in_qs[0], qrec.out_qs[0])
                    G.remove_and_reconnect(node, edge_class=NNEdge)
                    del G.quantization[NodeId(node)]
                    continue

            next_node = self.get_single_quantize_edge(G, node)
            while next_node:
                LOG.info(
                    'removing quantize node %s and modifying node %s to output %s',
                    next_node.name, node.name, next_node.to_qtype)
                G.remove_and_reconnect(next_node, edge_class=NNEdge)
                # next_node is no longer in the graph: drop it from the
                # worklist so it is not processed later as if it were.
                quantize_nodes = [
                    pending for pending in quantize_nodes
                    if pending is not next_node
                ]
                node.to_qtype = next_node.to_qtype
                modified_graph = True
                if G.quantization:
                    this_rec = G.quantization[NodeId(node)]
                    next_rec = G.quantization[NodeId(next_node)]
                    this_rec.out_qs = next_rec.out_qs
                    del G.quantization[NodeId(next_node)]
                next_node = self.get_single_quantize_edge(G, node)

        if set_identity:
            self.set_identity(G)

        return modified_graph
Exemplo n.º 7
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Remove reshapes that feed a single linear (FcParameters) layer.

        A ReshapeParameters node is removed when both conditions hold:

        * it has exactly one output edge;
        * that edge goes to an FcParameters node whose ``batch_size`` is
          not greater than 1.

        Its neighbours are reconnected with NNEdge, and any qrec it has is
        deleted.

        Args:
            G: graph rewritten in place.
            set_identity: when True, call ``self.set_identity(G)`` afterwards.

        Returns:
            True if at least one reshape was removed.
        """
        modified_graph = False
        # Snapshot the candidates: reshapes are removed from G in the loop.
        for node in list(G.nodes(node_classes=(ReshapeParameters, ))):
            out_edges = G.out_edges(node.name)
            if len(out_edges) != 1:
                continue
            consumer = out_edges[0].to_node
            if not isinstance(consumer, FcParameters) or consumer.batch_size > 1:
                continue
            LOG.info('removing unnecessary reshape before linear %s',
                     node.name)
            G.remove_and_reconnect(node, edge_class=NNEdge)
            modified_graph = True
            nid = NodeId(node)
            if G.quantization and G.quantization.get(nid):
                del G.quantization[nid]

        if set_identity:
            self.set_identity(G)

        return modified_graph
Exemplo n.º 8
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse node sets into ExpressionFusionParameters nodes.

        For every node set returned by ``self.find_sets(G)``:

        * the nodes and their internal edges are copied into a fragment
          GraphView;
        * an ``expr_<n>`` ExpressionFusionParameters node is built from the
          fragment, with input/output mappings and dims taken from the edges
          crossing the set boundary (``group_edges``);
        * the fragment is replaced in ``G`` by that node.

        If ``G`` is quantized, three more steps follow:

        * a 'scaled' QRec is created for the expression from the boundary
          qtypes, whose min/max are recorded as symbol stats;
        * QuantizeParameters nodes directly after the expression are removed
          so the quantizer can fuse them;
        * the expression is queued for requantization.

        Queued expressions are requantized with UnifiedQuantizer at the end.

        Args:
            G: graph rewritten in place.
            set_identity: when True, call ``self.set_identity(G)`` afterwards.

        Returns:
            True if at least one node set was fused.
        """
        has_modified_graph = False
        to_quantize = []
        node_sets = self.find_sets(G)
        for node_set in node_sets:
            # Fresh symbol stats for each expression (read back below through
            # Symbol.CURRENT_CONTROL.stats).
            Symbol.set_default_control(SymbolStats())
            has_modified_graph = True
            # in_edges / out_edges map (from_node, from_idx) -> set of edges
            # crossing the boundary of node_set.
            in_edges, out_edges, internal_edges = group_edges(G, node_set)
            frag = GraphView()
            for node in node_set:
                frag.add_node(node)
            for edge in internal_edges:
                frag.add_edge(edge)
            in_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_group]
                          for edge_group in in_edges.values()]
            in_dims = [
                from_node.out_dims[from_idx]
                for from_node, from_idx in in_edges
            ]
            out_dims = [
                from_node.out_dims[from_idx]
                for from_node, from_idx in out_edges
            ]
            out_mapping = list(out_edges.keys())
            constant_inputs = [
                node_edge_idx[0] for node_edge_idx in in_edges
                if isinstance(node_edge_idx[0], ConstantInputParameters)
            ]
            LOG.debug(
                "inputs coming from: %s",
                ",".join(f"{from_node.__repr__()}:{from_idx}"
                         for from_node, from_idx in in_edges))
            LOG.info("fusing nodes: %s into expr_%s",
                     ",".join(node.__repr__() for node in node_set),
                     self._expr_num)
            expr = ExpressionFusionParameters(
                G.unique_name(f"expr_{self._expr_num}"),
                subgraph=frag,
                qrecs=G.quantization,
                input_mapping=in_mapping,
                output_mapping=out_mapping,
                in_dims=in_dims,
                out_dims=out_dims,
                constant_inputs=constant_inputs)
            in_edge_mapping = list(in_edges.keys())
            out_edge_mapping = [[(edge.to_node, edge.to_idx)
                                 for edge in edge_set]
                                for edge_set in out_edges.values()]
            G.replace_fragment(
                frag,
                expr,
                frag_in_edges=list(set.union(*in_edges.values())),
                frag_out_edges=list(set.union(*out_edges.values())),
                edge_in_mapping=in_edge_mapping,
                edge_out_mapping=out_edge_mapping,
                edge_class=NNEdge)
            if G.quantization:
                qrecs = G.quantization
                # Input qtype of the first consumer of each boundary input and
                # output qtype of each boundary output.
                in_qs = [
                    qrecs[NodeId(in_map[0][0])].in_qs[in_map[0][1]]
                    for in_map in in_mapping
                ]
                out_qs = [
                    qrecs[NodeId(node)].out_qs[idx]
                    for node, idx in out_mapping
                ]
                stats = Symbol.CURRENT_CONTROL.stats
                func_col = expr.func_col
                for idx, qtype in enumerate(in_qs):
                    symbol = func_col.variables[func_col.input_names[idx]]
                    stats[symbol.name] = {
                        'min': qtype.min_val,
                        'max': qtype.max_val
                    }
                for idx, qtype in enumerate(out_qs):
                    symbol = func_col.variables[func_col.output_names[idx]]
                    stats[symbol.name] = {
                        'min': qtype.min_val,
                        'max': qtype.max_val
                    }
                G.quantization[NodeId(expr)] = QRec(in_qs=in_qs,
                                                    out_qs=out_qs,
                                                    expression=stats,
                                                    ktype='scaled')
                # Delete any quantize parameters on outputs to allow the
                # quantizer to fuse them into the expression. Reconnect with
                # NNEdge like every other removal in this module.
                expr_out_edges = G.out_edges(expr.name)
                for edge in expr_out_edges:
                    if isinstance(edge.to_node, QuantizeParameters):
                        G.remove_and_reconnect(edge.to_node, edge_class=NNEdge)
                        quant_nid = NodeId(edge.to_node)
                        if quant_nid in G.quantization:
                            del G.quantization[quant_nid]
                to_quantize.append(expr)

            self._expr_num += 1

        if to_quantize:
            quantizer = UnifiedQuantizer.from_quantized_graph(G)
            G.quantization = quantizer.quantize(G, start_nodes=to_quantize)

        if set_identity:
            self.set_identity(G)

        return has_modified_graph