Esempio n. 1
0
class PreOptimization(object):
    """Apply the non-precision-dependent graph optimizations to a model.

    The instance keeps the input graph, the input/output node names and a
    GraphAnalyzer that is re-pointed at the optimized graph once
    get_optimized_graphdef() has produced one.
    """

    def __init__(self, model, inputs, outputs):
        """Load the graph and work out its input/output node names.

        Args:
            model: passed unchanged to get_graph_def() together with the
                output node names.
            inputs (string list): input node names; may be empty/None.
            outputs (string list): output tensor or node names. Anything after
                the first ":" is dropped, so "dense/BiasAdd:0" becomes
                "dense/BiasAdd". May be empty/None.

        If either name list ends up empty, both lists are replaced by the
        result of GraphAnalyzer.get_graph_input_output().
        """
        # De-duplicate while keeping the caller's order. A plain set() would
        # hand the passes below an arbitrarily ordered list.
        seen_names = set()
        self.output_node_names = []
        for output in outputs or []:
            node_name = output.split(":")[0]
            if node_name not in seen_names:
                seen_names.add(node_name)
                self.output_node_names.append(node_name)

        self.input_graph = get_graph_def(model, self.output_node_names)

        # NOTE(review): the check is on the op type but the protected name is
        # the literal 'MakeIterator' -- this presumes the node is named after
        # its op; confirm against the models this is used with.
        has_iterator = any(
            node.op == 'MakeIterator' for node in self.input_graph.node)
        if has_iterator and 'MakeIterator' not in self.output_node_names:
            self.output_node_names.append('MakeIterator')

        self.analyzer = GraphAnalyzer()
        self.analyzer.graph = self.input_graph
        self.analyzer.parse_graph()
        self.input_node_names = inputs
        self.logger = logging.getLogger()
        self._tmp_graph_def = None
        self._excluded_node_names = []
        if not self.input_node_names or not self.output_node_names:
            # Names were not (fully) supplied: take both from the analyzer.
            self.input_node_names, self.output_node_names = \
                self.analyzer.get_graph_input_output()

    def get_excluded_node_names(self):
        """Get the excluded node names.

        Returns:
            string list: the names reported by UpdateEnterOptimizer during
                get_optimized_graphdef(); empty before that method has run.
        """
        return self._excluded_node_names

    @dump_elapsed_time("Pass Pre Optimization")
    def get_optimized_graphdef(self):
        """Execute the non-precision-dependent graph optimization.

        The input graph is run through the following passes, in this order:
        1. ConvertLayoutOptimizer.
        2. RemoveTrainingNodesOptimizer: remove the training nodes like the
           Identity op (output nodes are protected).
        3. SplitSharedInputOptimizer: split the shared nodes like a weights
           node feeding multiple Conv2D.
        4. GraphFoldConstantOptimizer: fold constant nodes.
        5. FuseColumnWiseMulOptimizer: fuse the Mul node into the previous
           Conv2D/MatMul if possible.
        6. StripUnusedNodesOptimizer: strip the useless nodes.
        7. GraphCseOptimizer: common sequence elimination on the graph.
        8. FoldBatchNormNodesOptimizer: fold the BN node into the previous
           Conv2D if possible.
        9. UpdateEnterOptimizer, which also reports the node names to exclude.

        Returns:
            [graphdef]: the optimized graphdef object.
        """
        self.logger.debug("Start to pre optimize input model...")

        self._tmp_graph_def = ConvertLayoutOptimizer(
            self.input_graph, self.output_node_names).do_transformation()

        self._tmp_graph_def = RemoveTrainingNodesOptimizer(
            self._tmp_graph_def,
            protected_nodes=self.output_node_names).do_transformation()

        self._tmp_graph_def = SplitSharedInputOptimizer(
            self._tmp_graph_def).do_transformation()

        self._tmp_graph_def = GraphFoldConstantOptimizer(
            self._tmp_graph_def).do_transformation()

        self._tmp_graph_def = FuseColumnWiseMulOptimizer(
            self._tmp_graph_def).do_transformation()

        self._tmp_graph_def = StripUnusedNodesOptimizer(
            self._tmp_graph_def, self.input_node_names,
            self.output_node_names).do_transformation()

        self._tmp_graph_def = GraphCseOptimizer(
            self._tmp_graph_def).do_transformation()

        self._tmp_graph_def = FoldBatchNormNodesOptimizer(
            self._tmp_graph_def).do_transformation()

        #TODO we should handle all control ops elegantly not bypass it.
        self._tmp_graph_def, excluded_node_names = UpdateEnterOptimizer(
            self._tmp_graph_def).do_transformation()

        # Record only names not seen yet, so calling this method more than
        # once does not pile up duplicates in the excluded list.
        for excluded_name in excluded_node_names:
            if excluded_name not in self._excluded_node_names:
                self._excluded_node_names.append(excluded_name)

        # Carry the function library of the input graph over to the result.
        self._tmp_graph_def.library.CopyFrom(self.input_graph.library)

        return self._tmp_graph_def

    def get_matched_nodes(self, patterns):
        """Search the matched nodes with the specified patterns.

        The analyzer is re-pointed at the graph produced by
        get_optimized_graphdef() and re-parsed on every call, so that method
        should be called first (until then the stored graph is None).

        Args:
            patterns ([string list]): The patterns should be illustrated as below.
                [['MatMul'], ("BiasAdd"), ("Relu",)]
        Returns:
            [string list]: It will return the list that contains the matched nodes name
                and pattern. ['matched_node_a_name', 'matched_node_a_name',['MatMul','BiasAdd']]
        """
        self.analyzer.graph = self._tmp_graph_def
        self.analyzer.parse_graph()

        matched = []
        for sub_pattern in patterns:
            matched.extend(
                self.analyzer.query_fusion_pattern_nodes(sub_pattern))

        return matched

    def has_positive_input(self, node_name):
        """Check whether the specified node has the positive input or not.

        The question is delegated to the analyzer, which answers for whichever
        graph it parsed last (the input graph, or the optimized graph after
        get_matched_nodes() has been called).

        Args:
            node_name ([string]): node's name

        Returns:
            [bool]: True if the node has the positive input data,
                    False if the node has the negative input data.
        """
        return self.analyzer.has_positive_input(node_name)
Esempio n. 2
0
class PreOptimization():
    """Apply the non-precision-dependent graph optimizations to a model.

    The instance keeps the source model, its input/output node names and a
    GraphAnalyzer that is re-pointed at the optimized graph once
    get_optimized_model() has produced one.
    """

    def __init__(self, model, optimization):
        """Capture the model and prepare the graph analyzer.

        Args:
            model: model object exposing graph_def, input_node_names,
                output_node_names and iter_op (plus _model,
                framework_specific_info and kwargs, which
                get_optimized_model() reads).
            optimization: forwarded unchanged to GrapplerOptimizer.
        """
        self.model = model
        self.optimization = optimization
        self.input_node_names = model.input_node_names

        output_node_names = model.output_node_names
        if model.iter_op is not None and 'MakeIterator' not in output_node_names:
            # Build a new list instead of appending in place: the list may be
            # owned by the model, and mutating it would leak 'MakeIterator'
            # into the model (once more per PreOptimization constructed).
            output_node_names = list(output_node_names) + ['MakeIterator']
        self.output_node_names = output_node_names

        self.analyzer = GraphAnalyzer()
        self.analyzer.graph = model.graph_def
        self.analyzer.parse_graph()
        self.logger = logging.getLogger()
        self._tmp_graph_def = None
        self._excluded_node_names = []

    def get_excluded_node_names(self):
        """Get the excluded node names.

        Returns:
            string list: the names reported by UpdateEnterOptimizer during
                get_optimized_model(); empty before that method has run.
        """
        return self._excluded_node_names

    @dump_elapsed_time("Pass Pre Optimization")
    def get_optimized_model(self):
        """Execute the non-precision-dependent graph optimization.

        The model's graph is run through the following passes, in this order:
        1. ConvertLayoutOptimizer.
        2. GrapplerOptimizer (configured by the `optimization` argument given
           to the constructor).
        3. RemoveTrainingNodesOptimizer: remove the training nodes like the
           Identity op (output nodes are protected).
        4. SplitSharedInputOptimizer: split the shared nodes like a weights
           node feeding multiple Conv2D.
        5. GraphFoldConstantOptimizer: fold constant nodes.
        6. FuseColumnWiseMulOptimizer: fuse the Mul node into the previous
           Conv2D/MatMul if possible.
        7. StripUnusedNodesOptimizer: strip the useless nodes.
        8. FuseGeluOptimizer.
        9. GraphCseOptimizer: common sequence elimination on the graph.
        10. FoldBatchNormNodesOptimizer: fold the BN node into the previous
            Conv2D if possible.
        11. UpdateEnterOptimizer, which also reports the node names to exclude.
        12. InjectDummyBiasAddOptimizer.
        13. FuseTransposeReshapeOptimizer.

        Returns:
            TensorflowModel: a new model built from the same underlying model,
                framework info and kwargs as the source model, with graph_def
                set to the optimized graph.
        """
        self.logger.debug("Start to pre optimize input model...")

        # Built before any pass runs, from the source model's own settings;
        # it receives the optimized graph at the end.
        origin_model = TensorflowModel(self.model._model,
                                       self.model.framework_specific_info,
                                       **self.model.kwargs)

        self._tmp_graph_def = ConvertLayoutOptimizer(
            self.model.graph_def, self.output_node_names).do_transformation()

        self._tmp_graph_def = GrapplerOptimizer(
            self._tmp_graph_def, self.output_node_names, self.optimization).do_transformation()

        self._tmp_graph_def = RemoveTrainingNodesOptimizer(
            self._tmp_graph_def, protected_nodes=self.output_node_names).do_transformation()

        self._tmp_graph_def = SplitSharedInputOptimizer(self._tmp_graph_def).do_transformation()

        self._tmp_graph_def = GraphFoldConstantOptimizer(self._tmp_graph_def).do_transformation()

        self._tmp_graph_def = FuseColumnWiseMulOptimizer(self._tmp_graph_def).do_transformation()

        self._tmp_graph_def = StripUnusedNodesOptimizer(self._tmp_graph_def,
            self.input_node_names, self.output_node_names).do_transformation()

        self._tmp_graph_def = FuseGeluOptimizer(self._tmp_graph_def).do_transformation()

        self._tmp_graph_def = GraphCseOptimizer(self._tmp_graph_def).do_transformation()

        self._tmp_graph_def = FoldBatchNormNodesOptimizer(
            self._tmp_graph_def).do_transformation()

        #TODO we should handle all control ops elegantly not bypass it.
        self._tmp_graph_def, excluded_node_names = UpdateEnterOptimizer(
            self._tmp_graph_def).do_transformation()

        #TODO we need to remove below optimizer once the TF enabled the single
        # matmul op quantization
        self._tmp_graph_def = InjectDummyBiasAddOptimizer(
            self._tmp_graph_def).do_transformation()

        self._tmp_graph_def = FuseTransposeReshapeOptimizer(
            self._tmp_graph_def).do_transformation()

        # Record only names not seen yet, so calling this method more than
        # once does not pile up duplicates in the excluded list.
        for excluded_name in excluded_node_names:
            if excluded_name not in self._excluded_node_names:
                self._excluded_node_names.append(excluded_name)

        # Carry the function library of the source graph over to the result.
        self._tmp_graph_def.library.CopyFrom(self.model.graph_def.library)

        origin_model.graph_def = self._tmp_graph_def

        return origin_model

    def get_matched_nodes(self, patterns):
        """Search the matched nodes with the specified patterns.

        The analyzer is re-pointed at the graph produced by
        get_optimized_model() and re-parsed on every call, so that method
        should be called first (until then the stored graph is None).
        Each distinct match is returned once, in first-seen order.

        Args:
            patterns ([string list]): The patterns should be illustrated as below.
                [['MatMul'], ("BiasAdd"), ("Relu",)]
        Returns:
            [string list]: It will return the list that contains the matched nodes name
                and pattern. ['matched_node_a_name', 'matched_node_a_name',['MatMul','BiasAdd']]
        """
        self.analyzer.graph = self._tmp_graph_def
        self.analyzer.parse_graph()

        matched = []
        for sub_pattern in patterns:
            # Check each candidate against the growing result. Filtering a
            # whole batch before extending would miss duplicates that occur
            # inside the same batch.
            for candidate in self.analyzer.query_fusion_pattern_nodes(sub_pattern):
                if candidate not in matched:
                    matched.append(candidate)
        return matched

    def has_positive_input(self, node_name):
        """Check whether the specified node has the positive input or not.

        The question is delegated to the analyzer, which answers for whichever
        graph it parsed last (the model's graph, or the optimized graph after
        get_matched_nodes() has been called).

        Args:
            node_name ([string]): node's name

        Returns:
            [bool]: True if the node has the positive input data,
                    False if the node has the negative input data.
        """
        return self.analyzer.has_positive_input(node_name)