Esempio n. 1
0
    def _freeze_keras_saved_model(self, saved_model_dir):
        """Freezes the model and returns the frozen GraphDef.

        The eval-tagged SavedModel is frozen with `freeze_graph`, keeping every
        output of `self._eval_signature`. Afterwards, only the constants whose
        names appear in `self._variable_names` are converted back into
        placeholders. Any other variables present in the eval model (e.g.
        non-trainable ones) are kept as constants.

        The intermediate frozen graph is written to a temporary directory that
        is always removed, even if freezing or parsing fails.

        Args:
          saved_model_dir: Directory with the Keras SavedModel export.

        Returns:
          Frozen GraphDef for the model.
        """
        output_names = [
            utils.tensor_to_op_name(output.name)
            for output in self._eval_signature.outputs.values()
        ]

        # The context manager removes the directory even when freeze_graph or
        # parsing raises; a trailing rmtree would be skipped in that case.
        with tempfile.TemporaryDirectory(
                suffix="tflite-transfer-convert") as temp_dir:
            graph_def_file_name = os.path.join(temp_dir, "frozen.pb")

            freeze_graph.freeze_graph(
                input_graph=None,
                input_saver=False,
                input_binary=True,
                input_checkpoint=None,
                output_node_names=",".join(output_names),
                restore_op_name=None,
                filename_tensor_name=None,
                output_graph=graph_def_file_name,
                clear_devices=True,
                initializer_nodes="",
                input_saved_model_dir=saved_model_dir,
                saved_model_tags="eval",
            )

            const_graph_def = tfv1.GraphDef()
            with open(graph_def_file_name, "rb") as graph_def_file:
                const_graph_def.ParseFromString(graph_def_file.read())

        # Convert constants produced from trainable variables to placeholders.
        # Note: eval model might have other variables that should not be trainable,
        # they are kept as constants. Only variables that are present in serve
        # model are converted.
        return utils.convert_constants_to_placeholders(
            const_graph_def, self._variable_names)
Esempio n. 2
0
    def _frozen_graph_def(self):
        """Freezes the model and returns the frozen GraphDef.

        The SavedModel at `self.model_dir` (loaded with tags `self.tag`) is
        frozen with `freeze_graph`. Only the op producing the first output of
        `self._signature` is kept as an output node. Afterwards, the constants
        whose names appear in `self._variable_names` are converted back into
        placeholders.

        The intermediate frozen graph is written to a temporary directory that
        is always removed, even if freezing or parsing fails.

        Returns:
          Frozen GraphDef for the model.

        Raises:
          ValueError: if `self._signature` has no outputs.
        """
        # NOTE(review): only the first signature output is frozen; this
        # presumably assumes a single-output signature — confirm with callers.
        first_output = next(iter(self._signature.outputs.values()), None)
        if first_output is None:
            raise ValueError("Signature has no outputs to freeze.")
        output_name = utils.tensor_to_op_name(first_output.name)

        # The context manager removes the directory even when freeze_graph or
        # parsing raises; a trailing rmtree would be skipped in that case.
        with tempfile.TemporaryDirectory(
                suffix="tflite-transfer-convert") as temp_dir:
            graph_def_file_name = os.path.join(temp_dir, "frozen.pb")

            freeze_graph.freeze_graph(
                input_graph=None,
                input_saver=False,
                input_binary=True,
                input_checkpoint=None,
                output_node_names=output_name,
                restore_op_name=None,
                filename_tensor_name=None,
                output_graph=graph_def_file_name,
                clear_devices=True,
                initializer_nodes="",
                input_saved_model_dir=self.model_dir,
                saved_model_tags=self.tag,
            )

            const_graph_def = tfv1.GraphDef()
            with open(graph_def_file_name, "rb") as graph_def_file:
                const_graph_def.ParseFromString(graph_def_file.read())

        # Convert constants produced from variables to placeholders.
        return utils.convert_constants_to_placeholders(
            const_graph_def, self._variable_names)