def setup(self):
    """
    Prepare the trainer so that the main loop can be started.

    In order: the subclass builds the graph through ``_setup()``, the
    monitors are wrapped in a ``Monitors`` object and registered as a
    callback, the callbacks set up their part of the graph, the session is
    created and initialized, the graph is finalized, and a ``HookedSession``
    is built from the session and the callbacks' hooks.
    """
    # Graph construction is delegated to the subclass.
    self._setup()

    self.monitors = Monitors(self.monitors)
    self.register_callback(self.monitors)

    describe_model()

    # Last operations that may still modify the graph.
    logger.info("Setup callbacks graph ...")
    callbacks = Callbacks(self._callbacks)
    self._callbacks = callbacks
    # The callbacks receive a weak proxy rather than the trainer itself.
    callbacks.setup_graph(weakref.proxy(self))

    # Session creation.
    logger.info("Finalize the graph, create the session ...")
    session = self.config.session_creator.create_session()
    self.sess = session
    self._monitored_sess = tf.train.MonitoredSession(
        session_creator=ReuseSessionCreator(session),
        hooks=None)

    # Session initialization: the global initializer runs first, then the
    # user-configured session_init.
    session.run(tf.global_variables_initializer())
    logger.info("Graph variables initialized.")
    self.config.session_init.init(session)

    # No graph modification is expected past this point.
    session.graph.finalize()

    self.hooked_sess = HookedSession(session, callbacks.get_hooks())
def setup(self):
    """
    Prepare the trainer so that the main loop can be started.

    In order: the subclass builds the graph through ``_setup()``, the
    monitors and the callbacks set themselves up, ``config.session_init``
    sets up its graph, and a ``tf.train.MonitoredSession`` is created whose
    scaffold runs the global variable initializer followed by
    ``config.session_init._run_init``. The underlying session is exposed as
    ``self.sess`` and combined with the callbacks' hooks in
    ``self.hooked_sess``.
    """
    # Graph construction is delegated to the subclass.
    self._setup()

    describe_model()

    # Last operations that may still modify the graph.
    logger.info("Setup monitors ...")
    self.monitors = Monitors(self.monitors)
    self.monitors.setup(weakref.proxy(self))

    logger.info("Setup callbacks graph ...")
    self._callbacks = Callbacks(self._callbacks)
    self._callbacks.setup_graph(weakref.proxy(self))

    self.config.session_init._setup_graph()

    def run_session_init(_scaffold, session):
        # Passed to the scaffold as init_fn; it performs the
        # user-configured initialization on the given session.
        logger.info("Graph variables initialized.")
        self.config.session_init._run_init(session)

    init_scaffold = tf.train.Scaffold(
        init_op=tf.global_variables_initializer(),
        init_fn=run_session_init)

    logger.info("Finalize the graph, create the session ...")
    creator = tf.train.ChiefSessionCreator(
        scaffold=init_scaffold,
        config=self.config.session_config)
    self.monitored_sess = tf.train.MonitoredSession(
        session_creator=creator, hooks=None)

    # Also expose the underlying session.
    self.sess = self.monitored_sess._tf_sess()

    self.hooked_sess = HookedSession(self.sess, self._callbacks.get_hooks())
def _before_train(self):
    """
    Build the hooked session, notify the input callbacks, and log how many
    iterations will be evaluated.

    When ``self._size`` is not positive the size is treated as unknown and a
    warning is logged instead.
    """
    trainer_session = self.trainer.sess
    self._hooked_sess = HookedSession(trainer_session, self._hooks)
    self._input_callbacks.before_train()

    if self._size > 0:
        logger.info("InferenceRunner will eval {} iterations".format(self._size))
        return

    # Unknown size: per the message below, iteration is expected to stop on
    # OutOfRangeError.
    logger.warn("InferenceRunner got an input with unknown size! It will iterate until OutOfRangeError!")
def _initialize_session(self):
    """
    Create and initialize the session, then build the hooked session.

    ``session_init`` sets up its graph before the session is created and
    runs its initialization right after. The input callbacks'
    ``before_train`` and the ``HookedSession`` construction both happen with
    the new session installed as the default session.
    """
    config = self._config

    # Graph-side setup has to precede session creation.
    config.session_init._setup_graph()
    session = config.session_creator.create_session()
    self._sess = session
    config.session_init._run_init(session)

    with session.as_default():
        self._input_callbacks.before_train()
        self._hooked_sess = HookedSession(session, self._hooks)
def setup(self):
    """
    Run the base-class setup, then create one loop thread per auxiliary
    train op.

    Each auxiliary op gets a ``_sess`` attribute: a ``_HookedSession`` around
    ``self.sess`` when the op has callbacks, otherwise ``self.sess`` itself.
    Every thread is paused before it is started and is recorded in
    ``self._training_aux_threads``.
    """
    super(MultiGPUTrainer, self).setup()

    self._training_aux_threads = []
    self._training_aux_running = False
    self._training_aux_step_counter = itertools.count()

    from tensorpack.callbacks.hooks import CallbackToHook
    from tensorflow.python.training.monitored_session \
        import _HookedSession as HookedSession

    def make_step(op):
        # A factory binds `op` per thread and so avoids the late-binding
        # closure pitfall inside the loop below.
        def step():
            try:
                # TODO this won't work with StageInput
                op._sess.run([op._train_op])
            except (RuntimeError,  # exited
                    tf.errors.CancelledError,
                    tf.errors.AbortedError):
                pass
        return step

    for thread_idx, key in enumerate(self._train_ops_aux):
        aux_op = self._train_ops_aux[key]
        logger.info("Create aux train ops {}".format(aux_op._name))

        if len(aux_op._callbacks) == 0:
            aux_op._sess = self.sess
        else:
            aux_hooks = [CallbackToHook(cb) for cb in aux_op._callbacks]
            aux_op._sess = HookedSession(self.sess, hooks=aux_hooks)

        worker = LoopThread(make_step(aux_op))
        worker.name = "AsyncLoopThread-{}".format(thread_idx)
        # Paused before start, as in the original statement order.
        worker.pause()
        worker.start()
        logger.info("Start aux thread {}".format(aux_op._name))
        self._training_aux_threads.append(worker)
def _before_train(self):
    """
    Build the hooked session from the trainer's session and ``self._hooks``,
    then call ``before_train`` on the input callbacks.
    """
    trainer_session = self.trainer.sess
    self._hooked_sess = HookedSession(trainer_session, self._hooks)

    self._input_callbacks.before_train()
def _before_train(self):
    """
    Run the parent class's ``_before_train``, then build a second hooked
    session from the trainer's session and ``self._hooks_parallel``.
    """
    super(DataParallelInferenceRunner, self)._before_train()

    trainer_session = self.trainer.sess
    parallel_hooks = self._hooks_parallel
    self._parallel_hooked_sess = HookedSession(trainer_session, parallel_hooks)
def _before_train(self):
    """
    Build two hooked sessions over the trainer's session: one with
    ``self._hooks`` and one with ``self._hooks_parallel``.
    """
    trainer_session = self.trainer.sess

    self._hooked_sess = HookedSession(trainer_session, self._hooks)
    self._parallel_hooked_sess = HookedSession(
        trainer_session, self._hooks_parallel)
def _before_train(self):
    """
    Append ``self._extra_hooks`` to ``self._hooks`` and build the hooked
    session from the trainer's session with the combined hooks.
    """
    # self._hooks is extended in place, so it keeps the extra hooks after
    # this call returns.
    all_hooks = self._hooks
    all_hooks.extend(self._extra_hooks)

    self._hooked_sess = HookedSession(self.trainer.sess, all_hooks)
def _before_train(self):
    """
    Build the hooked session from the trainer's session and ``self._hooks``.
    """
    trainer_session = self.trainer.sess
    self._hooked_sess = HookedSession(trainer_session, self._hooks)