Example #1
    def _TestGraph(self):
        params = model_registry.GetParams('test.LinearModelParams', 'Test')
        inference_graph = inference_graph_exporter.InferenceGraphExporter.Export(
            params)
        graph = tf.Graph()
        with graph.as_default():
            tf.import_graph_def(inference_graph.graph_def, name='')
        return graph, inference_graph
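As a quick illustration of how a helper like _TestGraph might be exercised, here is a hedged sketch of a companion test; the test name is illustrative, and it assumes inference_graph.fetches maps fetch names to tensor names in the imported graph.

    def testFetchesResolve(self):
        """Illustrative only: every recorded fetch resolves in the graph."""
        graph, inference_graph = self._TestGraph()
        for fetch in inference_graph.fetches.values():
            # get_tensor_by_name raises if the tensor does not exist.
            graph.get_tensor_by_name(fetch)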
Example #2
    def testExportFreezeDefault(self):
        """Test exporting frozen graph."""
        params = model_registry.GetParams('test.LinearModelParams', 'Test')
        inference_graph = inference_graph_exporter.InferenceGraphExporter.Export(
            params, freeze_defaults=True, subgraph_filter=['default'])
        self.assertIn('default', inference_graph.subgraphs)
        # Test graphs are well-formed and importable.
        with tf.Graph().as_default():
            tf.import_graph_def(inference_graph.graph_def)
Example #3
    def __init__(self,
                 inference_graph,
                 subgraph_name=None,
                 checkpoint=None,
                 device_type="gpu",
                 tf_master="",
                 session_config=None,
                 clear_device_placement=False):
        assert device_type in ["cpu", "gpu", "tpu"]
        subgraph_name = subgraph_name or "default"
        if isinstance(inference_graph, six.string_types):
            tf.logging.info("Reading inference graph from %s.",
                            inference_graph)
            inference_graph = LoadInferenceGraph(inference_graph,
                                                 clear_device_placement)
        self._inference_graph = inference_graph
        self._checkpoint = checkpoint
        self._device_type = device_type
        self._tf_master = tf_master
        self._session_config = session_config

        self._graph = tf.Graph()
        with self._graph.as_default():
            tf.logging.info(
                "Loading inference graph for prediction subgraph_name={}.".
                format(subgraph_name))
            self._saver = tf.train.Saver(saver_def=inference_graph.saver_def)
            # "%" binds tighter than the conditional, so parenthesize it to get
            # a fully formed device string (e.g. "/gpu:0"), not a bare "gpu".
            with tf.device("/%s:0" %
                           ("cpu" if device_type == "tpu" else device_type)):
                tf.import_graph_def(inference_graph.graph_def, name="")
            if device_type == "tpu":
                # If no tpu init op exists, create it here.
                try:
                    self._graph.get_operation_by_name("tpu_init_op")
                except KeyError:
                    tf.group(tf.tpu.initialize_system(), name="tpu_init_op")

            self._graph.finalize()

        if inference_graph.subgraphs:
            if subgraph_name not in inference_graph.subgraphs:
                raise ValueError(
                    "Subgraph %s not defined. Valid subgraphs: %s" %
                    (subgraph_name, list(inference_graph.subgraphs.keys())))
            subgraph = inference_graph.subgraphs[subgraph_name]
            self._fetches = subgraph.fetches
            self._feeds = subgraph.feeds
        else:
            self._fetches = inference_graph.fetches
            self._feeds = inference_graph.feeds

        # Lock for creating new sessions.
        self._sess_lock = threading.Lock()
        self._cur_sess_id = 0
        self._CreateNewSession()
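A minimal usage sketch for the constructor above, assuming it belongs to a class named Predictor (the snippet does not show the class name) and using placeholder paths rather than values from the source:

predictor = Predictor(
    inference_graph="/path/to/inference_graph.pbtxt",  # placeholder path
    subgraph_name="default",
    checkpoint="/path/to/model.ckpt",  # placeholder checkpoint prefix
    device_type="cpu")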
Example #4
    def _load_graph_from_inference_graph(self, inference_graph):
        """Returns a tf.Graph() constructed from `inference_graph`.

    Args:
      inference_graph: An InferenceGraph proto from which a graph_def is loaded
        from.

    Returns:
      A loaded tf.Graph().
    """
        graph = tf.Graph()
        with graph.as_default():
            with tf.device("/%s:0" % ("cpu" if self._device_type == "tpu"
                                      else self._device_type)):
                tf.import_graph_def(inference_graph.graph_def, name="")
        return graph
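The device expression in these constructors hinges on Python operator precedence: "%" binds tighter than a conditional expression, so without parentheses the non-TPU branch evaluates to the bare device type instead of a fully formed device string. A self-contained sketch of the difference:

device_type = "gpu"

# "%" formats first, then the conditional chooses between the formatted
# string and the bare device_type.
unparenthesized = "/%s:0" % "cpu" if device_type == "tpu" else device_type
assert unparenthesized == "gpu"

# Parenthesizing the conditional formats whichever device type was chosen.
parenthesized = "/%s:0" % ("cpu" if device_type == "tpu" else device_type)
assert parenthesized == "/gpu:0"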
Example #5
    def __init__(self,
                 inference_graph,
                 subgraph_name=None,
                 checkpoint=None,
                 device_type="gpu",
                 tf_master=""):
        assert device_type in ["cpu", "gpu", "tpu"]
        subgraph_name = subgraph_name or "default"
        if isinstance(inference_graph, six.string_types):
            tf.logging.info("Reading inference graph from %s.",
                            inference_graph)
            inference_graph = LoadInferenceGraph(inference_graph)
        self._inference_graph = inference_graph
        self._checkpoint = checkpoint
        self._device_type = device_type
        self._tf_master = tf_master

        self._graph = tf.Graph()
        with self._graph.as_default():
            tf.logging.info("Loading inference graph for prediction.")
            self._saver = tf.train.Saver(saver_def=inference_graph.saver_def)
            with tf.device("/%s:0" %
                           ("cpu" if device_type == "tpu" else device_type)):
                tf.import_graph_def(inference_graph.graph_def, name="")
            self._graph.finalize()

        if inference_graph.subgraphs:
            if subgraph_name not in inference_graph.subgraphs:
                raise ValueError(
                    "Subgraph %s not defined. Valid subgraphs: %s" %
                    (subgraph_name, list(inference_graph.subgraphs.keys())))
            subgraph = inference_graph.subgraphs[subgraph_name]
            self._fetches = subgraph.fetches
            self._feeds = subgraph.feeds
        else:
            self._fetches = inference_graph.fetches
            self._feeds = inference_graph.feeds

        # Lock for creating new sessions.
        self._sess_lock = threading.Lock()
        self._cur_sess_id = 0
        self._CreateNewSession()
Example #6
    def __init__(self,
                 inference_graph,
                 subgraph_name=None,
                 checkpoint=None,
                 device_type="gpu",
                 tf_master="",
                 session_config=None,
                 clear_device_placement=False):
        """Constructor.

    Args:
      inference_graph: A saved InferenceGraph proto.
      subgraph_name: The subgraph to use for prediction.
      checkpoint: An optional checkpoint to load.
      device_type: Device type string. Either "cpu", "gpu", or "tpu".
      tf_master: The tf_master.
      session_config: A tf.SessionConfig to use. By default
        py_utils.SessionConfig() is used.
      clear_device_placement: If set, clears device field of loaded inference
        graph.
    """
        assert device_type in ["cpu", "gpu", "tpu"]
        subgraph_name = subgraph_name or "default"
        if isinstance(inference_graph, str):
            tf.logging.info("Reading inference graph from %s.",
                            inference_graph)
            inference_graph = LoadInferenceGraph(inference_graph,
                                                 clear_device_placement)
        self._inference_graph = inference_graph
        self._checkpoint = checkpoint
        self._device_type = device_type
        self._tf_master = tf_master
        self._session_config = session_config

        self._graph = tf.Graph()
        with self._graph.as_default():
            tf.logging.info(
                "Loading inference graph for prediction subgraph_name={}.".
                format(subgraph_name))
            self._saver = tf.train.Saver(saver_def=inference_graph.saver_def)
            with tf.device("/%s:0" %
                           ("cpu" if device_type == "tpu" else device_type)):
                tf.import_graph_def(inference_graph.graph_def, name="")
            if device_type == "tpu":
                # If no tpu init op exists, create it here.
                try:
                    self._graph.get_operation_by_name("tpu_init_op")
                except KeyError:
                    tf.group(tf.tpu.initialize_system(), name="tpu_init_op")

            self._graph.finalize()

        if inference_graph.subgraphs:
            if subgraph_name not in inference_graph.subgraphs:
                raise ValueError(
                    "Subgraph %s not defined. Valid subgraphs: %s" %
                    (subgraph_name, list(inference_graph.subgraphs.keys())))
            subgraph = inference_graph.subgraphs[subgraph_name]
            self._fetches = subgraph.fetches
            self._feeds = subgraph.feeds
        else:
            self._fetches = inference_graph.fetches
            self._feeds = inference_graph.feeds

        # Lock for creating new sessions.
        self._sess_lock = threading.Lock()
        self._cur_sess_id = 0
        self._CreateNewSession()
Example #7
    def __init__(self,
                 inference_graph,
                 subgraph_name=None,
                 checkpoint=None,
                 device_type="gpu",
                 tf_master="",
                 session_config=None,
                 clear_device_placement=False,
                 load_graph_def_from_inference_graph=True):
        """Constructor.

    Args:
      inference_graph: A saved InferenceGraph proto.
      subgraph_name: The default subgraph to use for Run().
      checkpoint: An optional checkpoint to load.
      device_type: Device type string. Either "cpu", "gpu", or "tpu".
      tf_master: The tf_master.
      session_config: A tf.SessionConfig to use. By default
        py_utils.SessionConfig() is used.
      clear_device_placement: If set, clears device field of loaded inference
        graph.
      load_graph_def_from_inference_graph: Whether to load a graph def.
        If False, assumes the names in the inference graph correspond to tensors
        in the current default graph.
    """
        assert device_type in ["cpu", "gpu", "tpu"]
        subgraph_name = subgraph_name or "default"
        if isinstance(inference_graph, str):
            tf.logging.info("Reading inference graph from %s.",
                            inference_graph)
            inference_graph = LoadInferenceGraph(inference_graph,
                                                 clear_device_placement)
        self._inference_graph = inference_graph
        self._default_subgraph_name = subgraph_name
        self._checkpoint = checkpoint
        self._device_type = device_type
        self._tf_master = tf_master
        self._session_config = session_config

        if load_graph_def_from_inference_graph:
            self._graph = tf.Graph()
            with self._graph.as_default():
                tf.logging.info(
                    "Loading inference graph for prediction subgraph_name={}.".
                    format(subgraph_name))
                with tf.device("/%s:0" %
                               ("cpu" if device_type == "tpu" else device_type)):
                    tf.import_graph_def(inference_graph.graph_def, name="")
        else:
            self._graph = tf.get_default_graph()

        if device_type == "tpu":
            # If no tpu init op exists, create it here.
            try:
                self._graph.get_operation_by_name("tpu_init_op")
            except KeyError:
                with self._graph.as_default():
                    tf.group(tf.tpu.initialize_system(), name="tpu_init_op")

        self._graph.finalize()

        if inference_graph.subgraphs:
            if subgraph_name not in inference_graph.subgraphs:
                raise ValueError(
                    f"Subgraph {subgraph_name} not defined. Valid subgraphs: "
                    f"{self.subgraphs}")
            subgraph = inference_graph.subgraphs[subgraph_name]
            self._fetches = subgraph.fetches
            self._feeds = subgraph.feeds
        else:
            if "fetches" not in inference_graph or "feeds" not in inference_graph:
                raise ValueError(
                    "Graph does not contain feeds or fetches. Inference "
                    "graph is probably empty!")
            self._fetches = inference_graph.fetches
            self._feeds = inference_graph.feeds

        # Lock for creating new sessions.
        self._sess_lock = threading.Lock()
        self._cur_sess_id = 0
        self._create_new_session()
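Example #7's load_graph_def_from_inference_graph flag lets the caller resolve feeds and fetches against a graph that is already the default graph. A hedged sketch of that path, again assuming the enclosing class is named Predictor:

graph = tf.Graph()
with graph.as_default():
    # The caller imports (or builds) the graph itself...
    tf.import_graph_def(inference_graph.graph_def, name="")
    # ...and the constructor picks up tf.get_default_graph() instead of
    # importing the graph_def a second time.
    predictor = Predictor(
        inference_graph=inference_graph,
        load_graph_def_from_inference_graph=False,
        device_type="cpu")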