def _TestGraph(self):
  """Exports the 'test.LinearModelParams' Test model and re-imports it.

  Returns:
    A `(graph, inference_graph)` tuple: a fresh tf.Graph into which the
    exported graph_def was imported with no name prefix, and the exported
    inference graph itself.
  """
  model_params = model_registry.GetParams('test.LinearModelParams', 'Test')
  exported = inference_graph_exporter.InferenceGraphExporter.Export(
      model_params)
  imported_graph = tf.Graph()
  with imported_graph.as_default():
    # name='' keeps the exported node names unprefixed.
    tf.import_graph_def(exported.graph_def, name='')
  return imported_graph, exported
def testExportFreezeDefault(self):
  """Test exporting frozen graph.

  Exports only the 'default' subgraph with frozen defaults, checks that the
  subgraph is present, and checks that the resulting graph_def imports.
  """
  model_params = model_registry.GetParams('test.LinearModelParams', 'Test')
  exporter = inference_graph_exporter.InferenceGraphExporter
  exported = exporter.Export(
      model_params, freeze_defaults=True, subgraph_filter=['default'])
  self.assertIn('default', exported.subgraphs)
  # The exported graph_def must be well-formed: import it into a scratch graph.
  scratch_graph = tf.Graph()
  with scratch_graph.as_default():
    tf.import_graph_def(exported.graph_def)
def __init__(self,
             inference_graph,
             subgraph_name=None,
             checkpoint=None,
             device_type="gpu",
             tf_master="",
             session_config=None,
             clear_device_placement=False):
  """Constructor.

  Args:
    inference_graph: An InferenceGraph proto, or a path string from which one
      is read with LoadInferenceGraph.
    subgraph_name: The subgraph to use for prediction. Defaults to "default".
    checkpoint: An optional checkpoint; stored as `self._checkpoint`.
    device_type: Device type string. One of "cpu", "gpu" or "tpu".
    tf_master: The tf_master; stored as `self._tf_master`.
    session_config: Optional session config; stored as
      `self._session_config`.
    clear_device_placement: Forwarded to LoadInferenceGraph when
      `inference_graph` is a path string.

  Raises:
    ValueError: if the inference graph defines subgraphs but none named
      `subgraph_name`.
  """
  # NOTE: `assert` is stripped under `python -O`; kept so the failure type
  # (AssertionError) is unchanged for existing callers.
  assert device_type in ["cpu", "gpu", "tpu"]
  subgraph_name = subgraph_name or "default"
  if isinstance(inference_graph, six.string_types):
    tf.logging.info("Reading inference graph from %s.", inference_graph)
    inference_graph = LoadInferenceGraph(inference_graph,
                                         clear_device_placement)
  self._inference_graph = inference_graph
  self._checkpoint = checkpoint
  self._device_type = device_type
  self._tf_master = tf_master
  self._session_config = session_config

  # `%` binds tighter than a conditional expression, so the selection must be
  # parenthesized; otherwise only the TPU case got the "/<type>:0" form and
  # cpu/gpu fell through as a bare "cpu"/"gpu". For TPU the graph is imported
  # onto the host CPU.
  import_device = "/%s:0" % ("cpu" if device_type == "tpu" else device_type)

  self._graph = tf.Graph()
  with self._graph.as_default():
    tf.logging.info(
        "Loading inference graph for prediction subgraph_name={}.".format(
            subgraph_name))
    self._saver = tf.train.Saver(saver_def=inference_graph.saver_def)
    with tf.device(import_device):
      tf.import_graph_def(inference_graph.graph_def, name="")
    if device_type == "tpu":
      # If no tpu init op exists, create it here.
      try:
        self._graph.get_operation_by_name("tpu_init_op")
      except KeyError:
        tf.group(tf.tpu.initialize_system(), name="tpu_init_op")
  # No more ops may be added to the graph after this point.
  self._graph.finalize()

  if inference_graph.subgraphs:
    if subgraph_name not in inference_graph.subgraphs:
      raise ValueError(
          "Subgraph %s not defined. Valid subgraphs: %s" %
          (subgraph_name, list(inference_graph.subgraphs.keys())))
    subgraph = inference_graph.subgraphs[subgraph_name]
    self._fetches = subgraph.fetches
    self._feeds = subgraph.feeds
  else:
    # No subgraphs: fall back to the top-level feeds/fetches.
    self._fetches = inference_graph.fetches
    self._feeds = inference_graph.feeds

  # Lock for creating new sessions.
  self._sess_lock = threading.Lock()
  self._cur_sess_id = 0
  self._CreateNewSession()
def _load_graph_from_inference_graph(self, inference_graph):
  """Returns a tf.Graph() constructed from `inference_graph`.

  The graph_def is imported without a name prefix, under a device scope
  derived from `self._device_type` ("/cpu:0" when targeting TPU).

  Args:
    inference_graph: An InferenceGraph proto from which a graph_def is loaded.

  Returns:
    A loaded tf.Graph().
  """
  # `%` binds tighter than a conditional expression, so the device-type
  # selection must be parenthesized to always produce "/<type>:0".
  device_type = "cpu" if self._device_type == "tpu" else self._device_type
  import_device = "/%s:0" % device_type
  graph = tf.Graph()
  with graph.as_default():
    with tf.device(import_device):
      tf.import_graph_def(inference_graph.graph_def, name="")
  return graph
def __init__(self,
             inference_graph,
             subgraph_name=None,
             checkpoint=None,
             device_type="gpu",
             tf_master=""):
  """Constructor.

  Args:
    inference_graph: An InferenceGraph proto, or a path string from which one
      is read with LoadInferenceGraph.
    subgraph_name: The subgraph to use for prediction. Defaults to "default".
    checkpoint: An optional checkpoint; stored as `self._checkpoint`.
    device_type: Device type string. One of "cpu", "gpu" or "tpu".
    tf_master: The tf_master; stored as `self._tf_master`.

  Raises:
    ValueError: if the inference graph defines subgraphs but none named
      `subgraph_name`.
  """
  # NOTE: `assert` is stripped under `python -O`; kept so the failure type
  # (AssertionError) is unchanged for existing callers.
  assert device_type in ["cpu", "gpu", "tpu"]
  subgraph_name = subgraph_name or "default"
  if isinstance(inference_graph, six.string_types):
    tf.logging.info("Reading inference graph from %s.", inference_graph)
    inference_graph = LoadInferenceGraph(inference_graph)
  self._inference_graph = inference_graph
  self._checkpoint = checkpoint
  self._device_type = device_type
  self._tf_master = tf_master

  # `%` binds tighter than a conditional expression, so the selection must be
  # parenthesized; otherwise cpu/gpu became a bare "cpu"/"gpu" rather than
  # "/<type>:0". For TPU the graph is imported onto the host CPU.
  import_device = "/%s:0" % ("cpu" if device_type == "tpu" else device_type)

  self._graph = tf.Graph()
  with self._graph.as_default():
    tf.logging.info("Loading inference graph for prediction.")
    self._saver = tf.train.Saver(saver_def=inference_graph.saver_def)
    with tf.device(import_device):
      tf.import_graph_def(inference_graph.graph_def, name="")
  # No more ops may be added to the graph after this point.
  self._graph.finalize()

  if inference_graph.subgraphs:
    if subgraph_name not in inference_graph.subgraphs:
      raise ValueError(
          "Subgraph %s not defined. Valid subgraphs: %s" %
          (subgraph_name, list(inference_graph.subgraphs.keys())))
    subgraph = inference_graph.subgraphs[subgraph_name]
    self._fetches = subgraph.fetches
    self._feeds = subgraph.feeds
  else:
    # No subgraphs: fall back to the top-level feeds/fetches.
    self._fetches = inference_graph.fetches
    self._feeds = inference_graph.feeds

  # Lock for creating new sessions.
  self._sess_lock = threading.Lock()
  self._cur_sess_id = 0
  self._CreateNewSession()
def __init__(self,
             inference_graph,
             subgraph_name=None,
             checkpoint=None,
             device_type="gpu",
             tf_master="",
             session_config=None,
             clear_device_placement=False):
  """Constructor.

  Args:
    inference_graph: A saved InferenceGraph proto, or a path string from which
      one is read with LoadInferenceGraph.
    subgraph_name: The subgraph to use for prediction. Defaults to "default".
    checkpoint: An optional checkpoint to load.
    device_type: Device type string. Either "cpu", "gpu", or "tpu".
    tf_master: The tf_master.
    session_config: A tf.SessionConfig to use. By default
      py_utils.SessionConfig() is used.
    clear_device_placement: If set, clears device field of loaded inference
      graph.

  Raises:
    ValueError: if the inference graph defines subgraphs but none named
      `subgraph_name`.
  """
  # NOTE: `assert` is stripped under `python -O`; kept so the failure type
  # (AssertionError) is unchanged for existing callers.
  assert device_type in ["cpu", "gpu", "tpu"]
  subgraph_name = subgraph_name or "default"
  if isinstance(inference_graph, str):
    tf.logging.info("Reading inference graph from %s.", inference_graph)
    inference_graph = LoadInferenceGraph(inference_graph,
                                         clear_device_placement)
  self._inference_graph = inference_graph
  self._checkpoint = checkpoint
  self._device_type = device_type
  self._tf_master = tf_master
  self._session_config = session_config

  # `%` binds tighter than a conditional expression, so the selection must be
  # parenthesized; otherwise cpu/gpu became a bare "cpu"/"gpu" rather than
  # "/<type>:0". For TPU the graph is imported onto the host CPU.
  import_device = "/%s:0" % ("cpu" if device_type == "tpu" else device_type)

  self._graph = tf.Graph()
  with self._graph.as_default():
    tf.logging.info(
        "Loading inference graph for prediction subgraph_name={}.".format(
            subgraph_name))
    self._saver = tf.train.Saver(saver_def=inference_graph.saver_def)
    with tf.device(import_device):
      tf.import_graph_def(inference_graph.graph_def, name="")
    if device_type == "tpu":
      # If no tpu init op exists, create it here.
      try:
        self._graph.get_operation_by_name("tpu_init_op")
      except KeyError:
        tf.group(tf.tpu.initialize_system(), name="tpu_init_op")
  # No more ops may be added to the graph after this point.
  self._graph.finalize()

  if inference_graph.subgraphs:
    if subgraph_name not in inference_graph.subgraphs:
      # Single-line literal: the message was previously split across a line
      # break inside a non-triple-quoted string.
      raise ValueError(
          "Subgraph %s not defined. Valid subgraphs: %s" %
          (subgraph_name, list(inference_graph.subgraphs.keys())))
    subgraph = inference_graph.subgraphs[subgraph_name]
    self._fetches = subgraph.fetches
    self._feeds = subgraph.feeds
  else:
    # No subgraphs: fall back to the top-level feeds/fetches.
    self._fetches = inference_graph.fetches
    self._feeds = inference_graph.feeds

  # Lock for creating new sessions.
  self._sess_lock = threading.Lock()
  self._cur_sess_id = 0
  self._CreateNewSession()
def __init__(self,
             inference_graph,
             subgraph_name=None,
             checkpoint=None,
             device_type="gpu",
             tf_master="",
             session_config=None,
             clear_device_placement=False,
             load_graph_def_from_inference_graph=True):
  """Constructor.

  Args:
    inference_graph: A saved InferenceGraph proto, or a path string from which
      one is read with LoadInferenceGraph.
    subgraph_name: The default subgraph to use for Run(). Defaults to
      "default".
    checkpoint: An optional checkpoint to load.
    device_type: Device type string. Either "cpu", "gpu", or "tpu".
    tf_master: The tf_master.
    session_config: A tf.SessionConfig to use. By default
      py_utils.SessionConfig() is used.
    clear_device_placement: If set, clears device field of loaded inference
      graph.
    load_graph_def_from_inference_graph: Whether to load a graph def. If False,
      assumes the names in the inference graph correspond to tensors in the
      current default graph. Note that in that case the current default graph
      is the one finalized below.

  Raises:
    ValueError: if the inference graph defines subgraphs but none named
      `subgraph_name`, or if it defines no subgraphs and its top-level feeds
      or fetches are empty.
  """
  # NOTE: `assert` is stripped under `python -O`; kept so the failure type
  # (AssertionError) is unchanged for existing callers.
  assert device_type in ["cpu", "gpu", "tpu"]
  subgraph_name = subgraph_name or "default"
  if isinstance(inference_graph, str):
    tf.logging.info("Reading inference graph from %s.", inference_graph)
    inference_graph = LoadInferenceGraph(inference_graph,
                                         clear_device_placement)
  self._inference_graph = inference_graph
  self._default_subgraph_name = subgraph_name
  self._checkpoint = checkpoint
  self._device_type = device_type
  self._tf_master = tf_master
  self._session_config = session_config

  if load_graph_def_from_inference_graph:
    # `%` binds tighter than a conditional expression, so the selection must
    # be parenthesized; otherwise cpu/gpu became a bare "cpu"/"gpu" rather
    # than "/<type>:0". For TPU the graph is imported onto the host CPU.
    import_device = "/%s:0" % ("cpu" if device_type == "tpu" else device_type)
    self._graph = tf.Graph()
    with self._graph.as_default():
      tf.logging.info(
          "Loading inference graph for prediction subgraph_name={}.".format(
              subgraph_name))
      with tf.device(import_device):
        tf.import_graph_def(inference_graph.graph_def, name="")
  else:
    self._graph = tf.get_default_graph()

  if device_type == "tpu":
    # If no tpu init op exists, create it here.
    try:
      self._graph.get_operation_by_name("tpu_init_op")
    except KeyError:
      with self._graph.as_default():
        tf.group(tf.tpu.initialize_system(), name="tpu_init_op")

  # No more ops may be added to self._graph after this point (including the
  # caller's default graph when it was reused above).
  self._graph.finalize()

  if inference_graph.subgraphs:
    if subgraph_name not in inference_graph.subgraphs:
      raise ValueError(
          f"Subgraph {subgraph_name} not defined. Valid subgraphs: "
          f"{list(inference_graph.subgraphs.keys())}")
    subgraph = inference_graph.subgraphs[subgraph_name]
    self._fetches = subgraph.fetches
    self._feeds = subgraph.feeds
  else:
    # `"fetches" in <proto message>` does not test whether a map field has
    # entries (depending on the protobuf version it raises instead), so check
    # the maps for emptiness directly.
    if not inference_graph.fetches or not inference_graph.feeds:
      raise ValueError(
          "Graph does not contain feeds or fetches. Inference "
          "graph is probably empty!")
    self._fetches = inference_graph.fetches
    self._feeds = inference_graph.feeds

  # Lock for creating new sessions.
  self._sess_lock = threading.Lock()
  self._cur_sess_id = 0
  self._create_new_session()