def test_run_vars(self):
    app_driver = get_initialised_driver()
    test_graph = app_driver._create_graph(app_driver.graph)
    test_tensor = app_driver.graph.get_tensor_by_name(
        "G/conv_bn_selu/conv_/w:0")
    with self.test_session(graph=test_graph) as sess:
        app_driver._run_sampler_threads()
        app_driver._run_sampler_threads(sess)
        sess.run(app_driver._init_op)

        iter_msg = IterationMessage()

        # run 1st training iter
        iter_msg.current_iter, iter_msg.phase = 1, TRAIN
        app_driver.run_vars(sess, iter_msg)
        model_value_1 = sess.run(test_tensor)
        self.assertGreater(iter_msg.iter_duration, 0.0)
        print(iter_msg.to_console_string())
        self.assertRegexpMatches(iter_msg.to_console_string(), 'Training')

        # run 2nd training iter
        iter_msg.current_iter, iter_msg.phase = 2, TRAIN
        app_driver.run_vars(sess, iter_msg)
        model_value_2 = sess.run(test_tensor)
        # make sure the model gets updated
        self.assertNotAlmostEqual(
            np.mean(np.abs(model_value_1 - model_value_2)), 0.0)
        print(iter_msg.to_console_string())
        self.assertRegexpMatches(iter_msg.to_console_string(), 'Training')

        # run validation iter
        iter_msg.current_iter, iter_msg.phase = 3, VALID
        app_driver.run_vars(sess, iter_msg)
        model_value_3 = sess.run(test_tensor)
        # make sure the model does not get updated
        self.assertAlmostEqual(
            np.mean(np.abs(model_value_2 - model_value_3)), 0.0)
        print(iter_msg.to_console_string())
        self.assertRegexpMatches(iter_msg.to_console_string(), 'Validation')

        # run training iter
        iter_msg.current_iter, iter_msg.phase = 4, TRAIN
        app_driver.run_vars(sess, iter_msg)
        model_value_4 = sess.run(test_tensor)
        # make sure the model gets updated
        self.assertNotAlmostEqual(
            np.mean(np.abs(model_value_2 - model_value_4)), 0.0)
        self.assertNotAlmostEqual(
            np.mean(np.abs(model_value_3 - model_value_4)), 0.0)
        print(iter_msg.to_console_string())
        self.assertRegexpMatches(iter_msg.to_console_string(), 'Training')

        app_driver.app.stop()
        self.assertEqual(iter_msg.ops_to_run, {})
def _training_loop(self, sess, loop_status):
    """
    At each iteration an ``IterationMessage`` object is created to send
    network output to / receive controlling messages from self.app.
    The iteration message is passed into ``self.run_vars``, where the
    graph elements to run are collected and fed into ``session.run()``.

    A nested validation loop runs if ``self.validation_every_n > 0``;
    during validation the network parameters remain unchanged.
    """
    iter_msg = IterationMessage()

    # initialise tf summary writers
    writer_train = tf.summary.FileWriter(
        os.path.join(self.summary_dir, TRAIN), sess.graph)
    writer_valid = tf.summary.FileWriter(
        os.path.join(self.summary_dir, VALID), sess.graph) \
        if self.validation_every_n > 0 else None

    for iter_i in range(self.initial_iter, self.final_iter):
        # general loop information
        loop_status['current_iter'] = iter_i
        if self._coord.should_stop():
            break
        if iter_msg.should_stop:
            break

        # update variables/operations to run, from self.app
        iter_msg.current_iter, iter_msg.phase = iter_i, TRAIN
        self.run_vars(sess, iter_msg)

        self.app.interpret_output(
            iter_msg.current_iter_output[NETWORK_OUTPUT])
        iter_msg.to_tf_summary(writer_train)
        tf.logging.info(iter_msg.to_console_string())

        # run validations if required
        if iter_i > 0 and self.validation_every_n > 0 and \
                (iter_i % self.validation_every_n == 0):
            for _ in range(self.validation_max_iter):
                iter_msg.current_iter, iter_msg.phase = iter_i, VALID
                self.run_vars(sess, iter_msg)

                # save iteration results
                if writer_valid is not None:
                    iter_msg.to_tf_summary(writer_valid)
                tf.logging.info(iter_msg.to_console_string())

        if self.save_every_n > 0 and (iter_i % self.save_every_n == 0):
            self._save_model(sess, iter_i)
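# A minimal standalone sketch (not part of the driver) of the validation
# schedule encoded in ``_training_loop`` above: validation only runs after
# iteration 0, only when ``validation_every_n`` is positive, and only at
# iterations that are multiples of ``validation_every_n``.  The helper name
# is hypothetical.
def should_run_validation(iter_i, validation_every_n):
    """Return True if a validation loop would start at this training iteration."""
    return iter_i > 0 and validation_every_n > 0 \
        and iter_i % validation_every_n == 0


# e.g. with validation_every_n=5, validation fires at iterations 5 and 10
assert [i for i in range(1, 12) if should_run_validation(i, 5)] == [5, 10]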
def test_set_fields(self):
    msg = IterationMessage()

    # setting iter will clear tic and iter output fields
    msg.current_iter = 3
    self.assertGreater(msg._current_iter_tic, 0.0)
    self.assertEqual(msg._current_iter_output, None)

    # setting iter output will update iter duration
    msg.current_iter_output = {CONSOLE: {'test': 'test'}}
    self.assertEqual(msg.current_iter, 3)
    self.assertGreater(msg.iter_duration, 0.0)
    self.assertRegexpMatches(msg.to_console_string(), '.*test=test.*')

    with self.assertRaisesRegexp(ValueError, ''):
        msg.current_iter = 'test'
def _inference_loop(self, sess, loop_status):
    """
    Runs all variables returned by outputs_collector;
    the loop stops when the return value of
    application.interpret_output is False.
    """
    iter_msg = IterationMessage()
    loop_status['all_saved_flag'] = False
    iter_i = 0
    while True:
        if self._coord.should_stop():
            break
        if iter_msg.should_stop:
            break

        iter_msg.current_iter, iter_msg.phase = iter_i, INFER
        # run the variables provided in `iter_msg` and set their values
        # to iter_msg.current_iter_output
        self.run_vars(sess, iter_msg)
        iter_i = iter_i + 1

        # process the graph outputs
        if not self.app.interpret_output(
                iter_msg.current_iter_output[NETWORK_OUTPUT]):
            tf.logging.info('processed all batches.')
            loop_status['all_saved_flag'] = True
            break
        tf.logging.info(iter_msg.to_console_string())
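# Hypothetical sketch (not NiftyNet code) of the ``interpret_output`` contract
# that ``_inference_loop`` relies on: the application consumes one batch of
# network output per call and returns False once every sample has been
# processed, which terminates the loop above.  Class and attribute names are
# assumptions for illustration.
class ToyOutputInterpreter(object):
    """Stand-in application that stops inference after a fixed batch count."""

    def __init__(self, n_batches_expected):
        self._remaining = n_batches_expected

    def interpret_output(self, batch_output):
        # a real application would aggregate or write ``batch_output`` here
        self._remaining -= 1
        return self._remaining > 0  # returning False stops the inference loop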
def run(self, application, graph=None):
    """
    Initialise a TF graph, connect data sampler and network within
    the graph context, run training loops or inference loops.

    :param application: a niftynet application
    :param graph: default base graph to run the application
    :return:
    """
    if graph is None:
        graph = ApplicationDriver.create_graph(
            application=application,
            num_gpus=self.num_gpus,
            num_threads=self.num_threads,
            is_training_action=self.is_training_action)

    start_time = time.time()
    loop_status = {'current_iter': self.initial_iter,
                   'normal_exit': False}
    with tf.Session(config=tf_config(), graph=graph):
        try:
            # broadcasting event of session started
            SESS_STARTED.send(application, iter_msg=None)

            # create an iteration message generator and
            # iteratively run the graph (the main engine loop)
            iteration_messages = self._generator(**vars(self))()
            ApplicationDriver.loop(
                application=application,
                iteration_messages=iteration_messages,
                loop_status=loop_status)
        except KeyboardInterrupt:
            tf.logging.warning('User cancelled application')
        except (tf.errors.OutOfRangeError, EOFError):
            if not loop_status.get('normal_exit', False):
                # reached the end of the inference Dataset
                loop_status['normal_exit'] = True
        except RuntimeError:
            import sys
            import traceback
            exc_type, exc_value, exc_traceback = sys.exc_info()
            traceback.print_exception(
                exc_type, exc_value, exc_traceback, file=sys.stdout)
        finally:
            tf.logging.info('cleaning up...')

            # broadcasting session finished event
            iter_msg = IterationMessage()
            iter_msg.current_iter = loop_status.get('current_iter', -1)
            SESS_FINISHED.send(application, iter_msg=iter_msg)

            application.stop()
            if not loop_status.get('normal_exit', False):
                # loop didn't finish normally
                tf.logging.warning('stopped early, incomplete iterations.')
            tf.logging.info("%s stopped (time in second %.2f).",
                            type(application).__name__,
                            (time.time() - start_time))
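# A minimal sketch of observing the driver's session events.  It assumes the
# SESS_FINISHED signal broadcast by ``run`` above is a blinker-style signal
# whose receivers are called as ``receiver(sender, **kwargs)``; the handler
# name is hypothetical.
def log_session_finished(sender, **kwargs):
    # ``sender`` is the application instance, ``kwargs['iter_msg']`` the
    # final IterationMessage broadcast from the driver's ``finally`` block
    tf.logging.info('%s finished at iteration %d',
                    type(sender).__name__,
                    kwargs['iter_msg'].current_iter)


SESS_FINISHED.connect(log_session_finished)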
def test_interfaces(self):
    msg = IterationMessage()
    msg.current_iter = 0
    self.assertEqual(msg.current_iter, 0)
    self.assertEqual(msg.ops_to_run, {})
    self.assertEqual(msg.data_feed_dict, {})
    self.assertEqual(msg.current_iter_output, None)
    self.assertEqual(msg.should_stop, None)
    self.assertEqual(msg.phase, TRAIN)
    self.assertEqual(msg.is_training, True)
    self.assertEqual(msg.is_validation, False)
    self.assertEqual(msg.is_inference, False)

    msg.current_iter_output = {'test'}
    self.assertEqual(msg.current_iter_output, {'test'})
    self.assertGreater(msg.iter_duration, 0.0)
    self.assertStartsWith(msg.to_console_string(), 'training')
    self.assertEqual(msg.to_tf_summary(0), None)
def test_interfaces(self):
    msg = IterationMessage()
    msg.current_iter = 0
    self.assertEqual(msg.current_iter, 0)
    self.assertEqual(msg.ops_to_run, {})
    self.assertEqual(msg.data_feed_dict, {})
    self.assertEqual(msg.current_iter_output, None)
    self.assertEqual(msg.should_stop, False)
    self.assertEqual(msg.phase, TRAIN)
    self.assertEqual(msg.is_training, True)
    self.assertEqual(msg.is_validation, False)
    self.assertEqual(msg.is_inference, False)

    msg.current_iter_output = {'test'}
    self.assertEqual(msg.current_iter_output, {'test'})
    self.assertGreater(msg.iter_duration, 0.0)
    self.assertStartsWith(msg.to_console_string(), 'Training')
    self.assertEqual(msg.to_tf_summary(0), None)
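# A minimal sketch, inferred only from the assertions in the tests above, of
# the timing behaviour of IterationMessage.  The real NiftyNet class has more
# fields; the class name and the use of time.time() here are assumptions.
import time


class IterationMessageSketch(object):
    """Hypothetical stand-in illustrating the setter side effects."""

    def __init__(self):
        self._current_iter = -1
        self._current_iter_tic = time.time()
        self._current_iter_output = None
        self._iter_duration = 0.0

    @property
    def current_iter(self):
        return self._current_iter

    @current_iter.setter
    def current_iter(self, value):
        # setting the iteration number restarts the timer and clears output;
        # non-numeric values raise ValueError, as test_set_fields expects
        self._current_iter = int(value)
        self._current_iter_tic = time.time()
        self._current_iter_output = None

    @property
    def current_iter_output(self):
        return self._current_iter_output

    @current_iter_output.setter
    def current_iter_output(self, value):
        # setting the graph output records the elapsed iteration time
        self._current_iter_output = value
        self._iter_duration = time.time() - self._current_iter_tic

    @property
    def iter_duration(self):
        return self._iter_duration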
def test_init(self):
    ITER_FINISHED.connect(self.iteration_listener)

    app_driver = get_initialised_driver()
    app_driver.load_event_handlers([
        'niftynet.engine.handler_model.ModelRestorer',
        'niftynet.engine.handler_console.ConsoleLogger',
        'niftynet.engine.handler_sampler.SamplerThreading'])
    graph = app_driver.create_graph(app_driver.app, 1, True)
    with self.cached_session(graph=graph) as sess:
        SESS_STARTED.send(app_driver.app, iter_msg=None)
        msg = IterationMessage()
        msg.current_iter = 1
        app_driver.loop(app_driver.app, [msg])
        app_driver.app.stop()

    ITER_FINISHED.disconnect(self.iteration_listener)
def iter_generator(count_generator, phase):
    """
    Generate a numbered sequence of IterationMessage objects
    with phase-appropriate signals.

    ``count_generator`` is an iterable object yielding iteration numbers;
    ``phase`` is one of TRAIN, VALID or INFER.
    """
    signals = {
        TRAIN: (ApplicationDriver.pre_train_iter,
                ApplicationDriver.post_train_iter),
        VALID: (ApplicationDriver.pre_validation_iter,
                ApplicationDriver.post_validation_iter),
        INFER: (ApplicationDriver.pre_infer_iter,
                ApplicationDriver.post_infer_iter)}
    for iter_i in count_generator:
        iter_msg = IterationMessage()
        iter_msg.current_iter, iter_msg.phase = iter_i, phase
        iter_msg.pre_iter = signals[phase][0]
        iter_msg.post_iter = signals[phase][1]
        yield iter_msg
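# Usage sketch for ``iter_generator`` above (TRAIN as imported alongside the
# driver): a count generator plus a phase constant yields numbered messages
# carrying the matching pre-/post-iteration signals.  The print is
# illustrative only.
for train_msg in iter_generator(range(1, 4), TRAIN):
    print(train_msg.current_iter, train_msg.phase)  # iterations 1, 2, 3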
def run_application(self):
    """
    Initialise a TF graph, connect data sampler and network within
    the graph context, run training loops or inference loops.

    The training loop terminates when ``self.final_iter`` is reached.
    The inference loop terminates when there are no more image samples
    to be processed by the image reader.

    :return:
    """
    config = ApplicationDriver._tf_config()
    with tf.Session(config=config, graph=self.graph) as session:
        # start samplers' threads
        self._run_sampler_threads(session=session)
        self.graph = self._create_graph(self.graph)

        # check app variables are initialised and ready to start
        self.app.check_initialisations()

        # initialise network trainable parameters
        self._rand_init_or_restore_vars(session)

        start_time = time.time()
        loop_status = {}
        try:
            # iteratively run the graph
            if self.is_training:
                self.model_saver = ModelSaver(session,
                                              self.saver,
                                              self.save_every_n,
                                              self.session_prefix)
                loop_status['current_iter'] = self.initial_iter
                self._training_loop(session, loop_status)
            else:
                loop_status['all_saved_flag'] = False
                self._inference_loop(session, loop_status)
        except KeyboardInterrupt:
            tf.logging.warning('User cancelled application')
        except tf.errors.OutOfRangeError:
            if loop_status.get('all_saved_flag', None) is not None:
                # reached the end of the inference Dataset
                loop_status['all_saved_flag'] = True
        except RuntimeError:
            import sys
            import traceback
            exc_type, exc_value, exc_traceback = sys.exc_info()
            traceback.print_exception(
                exc_type, exc_value, exc_traceback, file=sys.stdout)
        finally:
            tf.logging.info('Cleaning up...')
            if self.is_training:
                # saving model at the last iteration
                iter_msg = IterationMessage()
                iter_msg.current_iter = loop_status.get('current_iter', -1)
                self.post_training.send(iter_msg)
            elif not loop_status.get('all_saved_flag', None):
                tf.logging.warning('stopped early, incomplete loops')

            tf.logging.info('stopping sampling threads')
            self.app.stop()
            tf.logging.info("%s stopped (time in second %.2f).",
                            type(self.app).__name__,
                            (time.time() - start_time))