def _CreateNewSession(self):
    """Builds a fresh session, initializes it, and stores it in self._sess.

    Steps, in order:
      1. Use self._session_config when it is truthy, otherwise fall back to
         py_utils.SessionConfig().
      2. Run the "init_all_tables" op; a KeyError there is logged at info
         level and otherwise ignored.
      3. When self._device_type is "tpu", run "tpu_init_op".
      4. Restore from self._checkpoint when one is set; otherwise try to run
         "init_all_variables", logging a warning on KeyError.
    """
    session_config = self._session_config or py_utils.SessionConfig()
    new_sess = tf.Session(
        self._tf_master, graph=self._graph, config=session_config)

    try:
        tables_init_op = self._graph.get_operation_by_name("init_all_tables")
        new_sess.run(tables_init_op)
    except KeyError:
        tf.logging.info("Could not find tables initializer in graph.")

    if self._device_type == "tpu":
        tpu_init_op = self._graph.get_operation_by_name("tpu_init_op")
        new_sess.run(tpu_init_op)

    if not self._checkpoint:
        # No checkpoint: fall back to the graph's own variable initializer.
        try:
            new_sess.run(
                self._graph.get_operation_by_name("init_all_variables"))
        except KeyError:
            tf.logging.warning(
                "No checkpoint provided and the graph has no default "
                "variable_init op.")
    else:
        self._saver.restore(new_sess, self._checkpoint)

    tf.logging.info("Created new predictor session.")
    self._sess = new_sess
def _GetSession(self, **kwargs):
    """Returns a new tf.Session connected to self._tf_master.

    Args:
      **kwargs: An optional 'graph' entry selects the graph for the session
        (default: self._graph). All remaining entries are forwarded to
        py_utils.SessionConfig.

    Raises:
      ValueError: If py_utils.IsEagerMode() is true.
    """
    if py_utils.IsEagerMode():
        raise ValueError('_GetSession is not supported in eager mode.')

    # 'graph' must be removed before the rest is handed to SessionConfig.
    session_graph = kwargs.pop('graph', self._graph)
    session_config = py_utils.SessionConfig(**kwargs)
    return tf.Session(
        self._tf_master, graph=session_graph, config=session_config)
def _CreateNewSession(self):
    """Updates self._sess with a new session.

    Creates a session on self._tf_master for self._graph, runs the
    "init_all_tables" op when the graph defines one, runs "tpu_init_op" when
    self._device_type is "tpu", and restores variables from self._checkpoint
    when a checkpoint is set.
    """
    sess = tf.Session(
        self._tf_master, graph=self._graph, config=py_utils.SessionConfig())

    # Graph.get_operation_by_name raises KeyError when the op is absent. A
    # graph without a tables initializer is still usable, so log and move on
    # instead of aborting session creation. Only the lookup is guarded so
    # that errors from actually running the op still propagate.
    try:
        init_tables_op = self._graph.get_operation_by_name("init_all_tables")
    except KeyError:
        tf.logging.info("Could not find tables initializer in graph.")
    else:
        sess.run(init_tables_op)

    if self._device_type == "tpu":
        sess.run(self._graph.get_operation_by_name("tpu_init_op"))
    if self._checkpoint:
        self._saver.restore(sess, self._checkpoint)
    tf.logging.info("Created new predictor session.")
    self._sess = sess
def _FreezeDefaults(graph, output_op_names):
    """Freezes a graph after running its default variable initializer.

    Args:
      graph: tf.Graph.
      output_op_names: Names of output ops.

    Returns:
      Resulting tf.GraphDef.
    """
    session_config = py_utils.SessionConfig()
    with tf.Session(graph=graph, config=session_config) as sess:
        init_op = graph.get_operation_by_name('init_all_variables')
        sess.run(init_op)
        # Freezing reads variable values, so it must run while the session
        # is still open.
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), output_op_names)
    return frozen_graph_def
def _FreezeGraphFromCheckpoint(graph, saver, checkpoint, output_op_names):
    """Freezes a graph from a checkpoint.

    Args:
      graph: tf.Graph.
      saver: The tf.Saver to use for restoration.
      checkpoint: The checkpoint to restore.
      output_op_names: Names of output ops.

    Returns:
      Resulting tf.GraphDef.
    """
    # Use the session as a context manager so it is closed (and its
    # resources released) even when restore or freezing raises.
    with tf.Session(graph=graph, config=py_utils.SessionConfig()) as sess:
        saver.restore(sess, checkpoint)
        return tf.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), output_op_names)
def _GetSession(self, **kwargs):
    """Returns a new tf.Session connected to self._tf_master.

    Args:
      **kwargs: An optional 'graph' entry selects the graph for the session
        (default: self._graph). All remaining entries are forwarded to
        py_utils.SessionConfig.
    """
    # 'graph' must be removed before the rest is handed to SessionConfig.
    session_graph = kwargs.pop('graph', self._graph)
    session_config = py_utils.SessionConfig(**kwargs)
    return tf.Session(
        self._tf_master, graph=session_graph, config=session_config)
def _GetSession(self, **kwargs):
    """Returns a new tf.Session on self._graph connected to self._tf_master.

    Args:
      **kwargs: Forwarded unchanged to py_utils.SessionConfig.
    """
    session_config = py_utils.SessionConfig(**kwargs)
    return tf.Session(
        self._tf_master, graph=self._graph, config=session_config)