Esempio n. 1
0
  def _CreateNewSession(self):
    """Updates self._sess with a new, initialized session.

    The session connects to `self._tf_master` on `self._graph`, using
    `self._session_config` when it is truthy and `py_utils.SessionConfig()`
    otherwise. Initialization then:
      1. runs the graph's "init_all_tables" op, if the graph has one;
      2. runs "tpu_init_op" when `self._device_type == "tpu"`;
      3. restores variables from `self._checkpoint` via `self._saver` if a
         checkpoint is set, else runs "init_all_variables" if present.

    If any initialization step raises, the new session is closed before the
    exception propagates, and `self._sess` is left unchanged.
    """
    config = self._session_config or py_utils.SessionConfig()
    sess = tf.Session(self._tf_master, graph=self._graph, config=config)
    try:
      # Only the op lookup belongs in the KeyError guard; an error raised
      # while *running* the initializer must not be reported as "missing".
      try:
        tables_init_op = self._graph.get_operation_by_name("init_all_tables")
      except KeyError:
        tf.logging.info("Could not find tables initializer in graph.")
      else:
        sess.run(tables_init_op)

      if self._device_type == "tpu":
        sess.run(self._graph.get_operation_by_name("tpu_init_op"))

      if self._checkpoint:
        self._saver.restore(sess, self._checkpoint)
      else:
        try:
          init_op = self._graph.get_operation_by_name("init_all_variables")
        except KeyError:
          tf.logging.warning(
              "No checkpoint provided and the graph has no default "
              "variable_init op.")
        else:
          sess.run(init_op)
    except BaseException:
      # Close the partially-initialized session so it is not leaked when
      # initialization fails; the original exception still propagates.
      sess.close()
      raise

    tf.logging.info("Created new predictor session.")
    self._sess = sess
Esempio n. 2
0
 def _GetSession(self, **kwargs):
     """Creates a tf.Session connected to `self._tf_master`.

     Args:
       **kwargs: An optional `graph` keyword selects the session's graph
         (default: `self._graph`); every remaining keyword is forwarded to
         `py_utils.SessionConfig`.

     Returns:
       A new tf.Session.

     Raises:
       ValueError: if `py_utils.IsEagerMode()` is true.
     """
     if py_utils.IsEagerMode():
         raise ValueError('_GetSession is not supported in eager mode.')
     target_graph = kwargs.pop('graph', self._graph)
     session_config = py_utils.SessionConfig(**kwargs)
     return tf.Session(
         self._tf_master, graph=target_graph, config=session_config)
Esempio n. 3
0
 def _CreateNewSession(self):
     """Updates self._sess with a new, initialized session.

     The session connects to `self._tf_master` on `self._graph` with the
     default `py_utils.SessionConfig()`. It then runs the graph's
     "init_all_tables" op, runs "tpu_init_op" when
     `self._device_type == "tpu"`, and restores variables from
     `self._checkpoint` via `self._saver` if a checkpoint is set.

     If any initialization step raises, the new session is closed before the
     exception propagates and `self._sess` is left unchanged.

     Raises:
       KeyError: if the graph has no "init_all_tables" op, or (on TPU) no
         "tpu_init_op" op.
     """
     sess = tf.Session(self._tf_master,
                       graph=self._graph,
                       config=py_utils.SessionConfig())
     try:
         sess.run(self._graph.get_operation_by_name("init_all_tables"))
         if self._device_type == "tpu":
             sess.run(self._graph.get_operation_by_name("tpu_init_op"))
         if self._checkpoint:
             self._saver.restore(sess, self._checkpoint)
     except BaseException:
         # Close the half-initialized session instead of leaking it.
         sess.close()
         raise
     tf.logging.info("Created new predictor session.")
     self._sess = sess
def _FreezeDefaults(graph, output_op_names):
    """Default-initializes a graph's variables and freezes it.

    Runs the graph's `init_all_variables` op in a temporary session, then
    converts the variables into constants.

    Args:
      graph: tf.Graph.
      output_op_names: Names of output ops.

    Returns:
      Resulting tf.GraphDef.
    """
    session_config = py_utils.SessionConfig()
    with tf.Session(graph=graph, config=session_config) as sess:
        init_op = graph.get_operation_by_name('init_all_variables')
        sess.run(init_op)
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), output_op_names)
    return frozen_graph_def
def _FreezeGraphFromCheckpoint(graph, saver, checkpoint, output_op_names):
    """Freezes a graph from a checkpoint.

    Restores variable values from `checkpoint` into a temporary session and
    returns a GraphDef in which those variables are replaced by constants.

    Args:
      graph: tf.Graph.
      saver: The tf.Saver to use for restoration.
      checkpoint: The checkpoint to restore.
      output_op_names: Names of output ops.

    Returns:
      Resulting tf.GraphDef.
    """
    # The session is only needed to read variable values. Using it as a
    # context manager closes it once freezing is done, or if restore fails.
    with tf.Session(graph=graph, config=py_utils.SessionConfig()) as sess:
        saver.restore(sess, checkpoint)
        return tf.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), output_op_names)
Esempio n. 6
0
 def _GetSession(self, **kwargs):
     """Creates a tf.Session connected to `self._tf_master`.

     Args:
       **kwargs: An optional `graph` keyword picks the session's graph
         (default: `self._graph`); every other keyword is passed through to
         `py_utils.SessionConfig`.

     Returns:
       A new tf.Session.
     """
     session_graph = kwargs.pop('graph', self._graph)
     session_config = py_utils.SessionConfig(**kwargs)
     return tf.Session(
         self._tf_master, graph=session_graph, config=session_config)
Esempio n. 7
0
 def _GetSession(self, **kwargs):
     """Creates a tf.Session on `self._graph` connected to `self._tf_master`.

     Args:
       **kwargs: Forwarded unchanged to `py_utils.SessionConfig`.

     Returns:
       A new tf.Session.
     """
     session_config = py_utils.SessionConfig(**kwargs)
     return tf.Session(
         self._tf_master, graph=self._graph, config=session_config)