Ejemplo n.º 1
0
 def run(self, v, run_option=None, run_statistic=None):
     """Run `v` in the managed session, creating the session lazily.

     Calls `current_graph().unfinalize()` first. If no session exists yet,
     the client is restarted and a new `SimpleSession` is built from
     `self._hooks`.

     Args:
       v: what to run; forwarded unchanged to `SimpleSession.run`.
       run_option: forwarded unchanged to `SimpleSession.run`.
       run_statistic: forwarded unchanged to `SimpleSession.run`.

     Returns:
       The result of `self._session.run(...)`. Returns None when a `PsError`
       or `OutOfRange` was caught.

     On `PsError` the message is printed, the method sleeps 5 seconds, and
     `self._session` is reset to None so the next call builds a fresh session.
     On `OutOfRange`, `self._finish` is set to True.
     """
     current_graph().unfinalize()
     try:
         if self._session is None:
             self._restart_client()
             self._session = SimpleSession(self._hooks)
         return self._session.run(v, run_option, run_statistic)
     except PsError as e:
         # print() does not %-format extra positional arguments, so the
         # error text must be interpolated explicitly; otherwise a literal
         # '%s' is printed followed by the error as a second argument.
         print(
             'An error was raised. This may be due to a preemption in '
             'a connected worker or parameter server. The current '
             'session will be closed and a new session will be '
             'created. Error: %s' % str(e))
         time.sleep(5)
         # Dropping the session forces a restart on the next run() call.
         self._session = None
     except OutOfRange:
         # NOTE(review): presumably signals callers that input is exhausted
         # and the loop should stop — confirm against readers of _finish.
         self._finish = True
Ejemplo n.º 2
0
 def __init__(self, outs, g=None, weak_variable=False, slow_start=1000):
     """Store configuration for `outs` and build against graph `g`.

     Args:
       outs: output(s) to track; flattened via `_flatten_outs`.
       g: graph to bind to; `current_graph()` is used when None.
       weak_variable: stored as `self._weak_variable`.
       slow_start: stored as `self._slow_start` (defaults to 1000).
     """
     self._g = current_graph() if g is None else g
     self._outs = _flatten_outs(outs)
     self._real_outs = []
     self._weak_variable, self._slow_start = weak_variable, slow_start
     self._step = False
     # `_run_time` is only initialised after `_build()` returns.
     self._build()
     self._run_time = 0
Ejemplo n.º 3
0
 def __init__(self):
     """Capture the checkpoint directory and serialize the current graph.

     Raises:
       ValueError: if `xdl.get_ckpt_dir()` returns None.
     """
     ckpt_dir = xdl.get_ckpt_dir()
     self._ckpt_dir = ckpt_dir
     if ckpt_dir is None:
         raise ValueError('must specify ckpt_dir arg in cmdline')
     graph = current_graph()
     self._graph_def = _graphdef_to_pb(graph._graph_def)
Ejemplo n.º 4
0
 def enqueue_start(self):
     """Create an enqueue op for the stored tensors and run it in a loop.

     The stored names are joined with ';' before being passed to
     `xdl.enqueue_op`. After `execute_loop` returns, the top entry of
     `xdl.Graph._current_graph` is popped.
     """
     tensors = self._enqueue_tensors
     joined_names = ";".join(self._names)
     enqueue_op = xdl.enqueue_op(tensors=tensors, names=joined_names)
     current_graph().execute_loop(enqueue_op)
     # NOTE(review): presumably balances a graph push made elsewhere; the
     # pop is skipped if execute_loop raises — confirm that is intended.
     xdl.Graph._current_graph.pop()
Ejemplo n.º 5
0
 def __init__(self, ckpt_dir=None, tf_graph_name=None):
     """Record checkpoint location, serialized graph and TF graph name.

     Args:
       ckpt_dir: checkpoint directory; `get_ckpt_dir()` is used when None.
       tf_graph_name: stored unchanged as `self._tf_graph_name`.
     """
     self._ckpt_dir = get_ckpt_dir() if ckpt_dir is None else ckpt_dir
     self._graph_def = _graphdef_to_pb(current_graph()._graph_def)
     self._tf_graph_name = tf_graph_name
Ejemplo n.º 6
0
 def __init__(self, ckpt_dir=None):
     """Store `ckpt_dir` and the serialized current graph definition.

     `ckpt_dir` is kept exactly as passed, including None.
     """
     self._ckpt_dir = ckpt_dir
     graph_def = current_graph()._graph_def
     self._graph_def = _graphdef_to_pb(graph_def)
Ejemplo n.º 7
0
 def feed_output_list(self):
   """Record, for every input of every node in the current graph, its consumers.

   For each op returned by `current_graph().nodes()` and each of its inputs
   at position `i`, appends the pair `(op, i)` to `self._output_list[input]`.
   """
   # NOTE(review): assumes self._output_list creates missing keys on demand
   # (e.g. a collections.defaultdict(list)) — confirm where it is built.
   for op in current_graph().nodes().values():
     # enumerate replaces the manual range(len(...)) index loop.
     for i, inp in enumerate(op.inputs):
       self._output_list[inp].append((op, i))