Ejemplo n.º 1
0
    def _inference_loop(self, sess, loop_status):
        """
        Repeatedly run the inference graph and hand each network output
        to the application.

        The loop ends when the coordinator requests a stop, when the
        iteration message has ``should_stop`` set, or when
        ``self.app.interpret_output`` returns a false value.  Only in the
        last case is ``loop_status['all_saved_flag']`` set to ``True``.

        :param sess: session forwarded to ``self.run_vars``
        :param loop_status: dict updated in place; its
            ``'all_saved_flag'`` entry is reset to ``False`` on entry
        """
        message = IterationMessage()
        loop_status['all_saved_flag'] = False
        step = 0
        while not (self._coord.should_stop() or message.should_stop):
            message.current_iter, message.phase = step, INFER
            # `run_vars` executes the variables requested through the
            # message and stores the results in `current_iter_output`
            self.run_vars(sess, message)
            step += 1

            # let the application consume the graph outputs
            network_output = message.current_iter_output[NETWORK_OUTPUT]
            if self.app.interpret_output(network_output):
                tf.logging.info(message.to_console_string())
                continue

            # the application signalled that nothing is left to process
            tf.logging.info('processed all batches.')
            loop_status['all_saved_flag'] = True
            break
Ejemplo n.º 2
0
    def run(self, application, graph=None):
        """
        Run the main engine loop for ``application`` inside a TF session.

        When no graph is given, one is created with
        ``ApplicationDriver.create_graph`` from this driver's settings.
        ``SESS_STARTED`` is broadcast before the loop and
        ``SESS_FINISHED`` is always broadcast afterwards, carrying the
        last iteration recorded in the loop status.

        :param application: a niftynet application
        :param graph: default base graph to run the application
        :return:
        """
        if graph is None:
            graph = ApplicationDriver.create_graph(
                application=application,
                num_gpus=self.num_gpus,
                num_threads=self.num_threads,
                is_training_action=self.is_training_action)

        start_time = time.time()
        loop_status = {
            'current_iter': self.initial_iter,
            'normal_exit': False,
        }

        with tf.Session(config=tf_config(), graph=graph):
            try:
                # announce that the session is ready
                SESS_STARTED.send(application, iter_msg=None)

                # build the iteration message generator, then let the
                # engine loop consume it
                message_factory = self._generator(**vars(self))
                iteration_messages = message_factory()
                ApplicationDriver.loop(
                    application=application,
                    iteration_messages=iteration_messages,
                    loop_status=loop_status)

            except KeyboardInterrupt:
                tf.logging.warning('User cancelled application')
            except (tf.errors.OutOfRangeError, EOFError):
                # reached the end of inference Dataset
                if not loop_status.get('normal_exit', False):
                    loop_status['normal_exit'] = True
            except RuntimeError:
                import sys
                import traceback
                # print the active exception's traceback to stdout
                traceback.print_exc(file=sys.stdout)
            finally:
                tf.logging.info('cleaning up...')
                # announce that the session has finished
                final_msg = IterationMessage()
                final_msg.current_iter = loop_status.get('current_iter', -1)
                SESS_FINISHED.send(application, iter_msg=final_msg)

        application.stop()
        finished_normally = loop_status.get('normal_exit', False)
        if not finished_normally:
            # loop didn't finish normally
            tf.logging.warning('stopped early, incomplete iterations.')
        tf.logging.info(
            "%s stopped (time in second %.2f).",
            type(application).__name__,
            (time.time() - start_time))
Ejemplo n.º 3
0
    def _training_loop(self, sess, loop_status):
        """
        At each iteration, an ``IterationMessage`` object is used
        to send network output to/receive controlling messages from self.app.
        The iteration message is passed into `self.run_vars`,
        where graph elements to run are collected and fed into `session.run()`.
        A nested validation loop runs every ``self.validation_every_n``
        iterations when that value is positive.

        The summary writers opened here are always closed on exit (normal
        termination, early stop or exception) so that buffered events are
        flushed and the event files are released.

        :param sess: session forwarded to ``self.run_vars``; its graph is
            attached to the summary writers
        :param loop_status: dict updated in place; ``'current_iter'`` is
            set to the iteration number at the start of every iteration
        """

        iter_msg = IterationMessage()

        # initialise tf summary writers
        writer_train = tf.summary.FileWriter(
            os.path.join(self.summary_dir, TRAIN), sess.graph)
        writer_valid = None
        try:
            # the validation writer is only needed when validation is enabled
            if self.validation_every_n > 0:
                writer_valid = tf.summary.FileWriter(
                    os.path.join(self.summary_dir, VALID), sess.graph)

            for iter_i in range(self.initial_iter, self.final_iter):
                # general loop information
                loop_status['current_iter'] = iter_i
                if self._coord.should_stop():
                    break
                if iter_msg.should_stop:
                    break

                # update variables/operations to run, from self.app
                iter_msg.current_iter, iter_msg.phase = iter_i, TRAIN
                self.run_vars(sess, iter_msg)

                self.app.interpret_output(
                    iter_msg.current_iter_output[NETWORK_OUTPUT])
                iter_msg.to_tf_summary(writer_train)
                tf.logging.info(iter_msg.to_console_string())

                # run validations if required
                if iter_i > 0 and self.validation_every_n > 0 and \
                        (iter_i % self.validation_every_n == 0):
                    for _ in range(self.validation_max_iter):
                        iter_msg.current_iter, iter_msg.phase = iter_i, VALID
                        self.run_vars(sess, iter_msg)
                        # save iteration results
                        if writer_valid is not None:
                            iter_msg.to_tf_summary(writer_valid)
                        tf.logging.info(iter_msg.to_console_string())

                if self.save_every_n > 0 and (iter_i % self.save_every_n == 0):
                    self._save_model(sess, iter_i)
        finally:
            # flush pending summaries and release the event files even
            # when the loop is interrupted by an exception
            writer_train.close()
            if writer_valid is not None:
                writer_valid.close()
Ejemplo n.º 4
0
    def test_set_fields(self):
        """Check the side effects of assigning ``IterationMessage`` fields."""
        message = IterationMessage()

        # after assigning the iteration number, the tic is positive
        # and no iteration output is stored
        message.current_iter = 3
        self.assertGreater(message._current_iter_tic, 0.0)
        self.assertEqual(message._current_iter_output, None)

        # after assigning the output, the iteration number is kept
        # and a positive duration is available
        console_payload = {CONSOLE: {'test': 'test'}}
        message.current_iter_output = console_payload
        self.assertEqual(message.current_iter, 3)
        self.assertGreater(message.iter_duration, 0.0)
        console_text = message.to_console_string()
        self.assertRegexpMatches(console_text, '.*test=test.*')

        # a non-numeric iteration number is rejected
        with self.assertRaisesRegexp(ValueError, ''):
            message.current_iter = 'test'
Ejemplo n.º 5
0
 def test_interfaces(self):
     """Check the public attributes of a fresh ``IterationMessage``."""
     message = IterationMessage()
     message.current_iter = 0

     # (attribute name, expected value) pairs, verified in this order
     expected_values = (
         ('current_iter', 0),
         ('ops_to_run', {}),
         ('data_feed_dict', {}),
         ('current_iter_output', None),
         ('should_stop', False),
         ('phase', TRAIN),
         ('is_training', True),
         ('is_validation', False),
         ('is_inference', False),
     )
     for attribute, expected in expected_values:
         self.assertEqual(getattr(message, attribute), expected)

     # storing an output makes it readable back and yields a duration
     message.current_iter_output = {'test'}
     self.assertEqual(message.current_iter_output, {'test'})
     self.assertGreater(message.iter_duration, 0.0)
     console_text = message.to_console_string()
     self.assertStartsWith(console_text, 'Training')
     self.assertEqual(message.to_tf_summary(0), None)
Ejemplo n.º 6
0
    def test_init(self):
        """
        Run one engine loop iteration with the model-restorer, console
        and sampler-threading handlers loaded, while
        ``self.iteration_listener`` is connected to ``ITER_FINISHED``.
        """
        ITER_FINISHED.connect(self.iteration_listener)
        try:
            app_driver = get_initialised_driver()
            app_driver.load_event_handlers([
                'niftynet.engine.handler_model.ModelRestorer',
                'niftynet.engine.handler_console.ConsoleLogger',
                'niftynet.engine.handler_sampler.SamplerThreading'
            ])
            graph = app_driver.create_graph(app_driver.app, 1, True)
            with self.cached_session(graph=graph) as sess:
                SESS_STARTED.send(app_driver.app, iter_msg=None)
                msg = IterationMessage()
                msg.current_iter = 1
                app_driver.loop(app_driver.app, [msg])
            app_driver.app.stop()
        finally:
            # always detach the listener so that a failure in this test
            # does not leave it connected while other tests run
            ITER_FINISHED.disconnect(self.iteration_listener)
Ejemplo n.º 7
0
def iter_generator(count_generator, phase):
    """ Yield one IterationMessage per iteration number, each tagged
    with the given phase and that phase's pre/post iteration signals.
    count_generator is an iterable object yielding iteration numbers
    phase is one of TRAIN, VALID or INFER
    """
    train_signals = (ApplicationDriver.pre_train_iter,
                     ApplicationDriver.post_train_iter)
    valid_signals = (ApplicationDriver.pre_validation_iter,
                     ApplicationDriver.post_validation_iter)
    infer_signals = (ApplicationDriver.pre_infer_iter,
                     ApplicationDriver.post_infer_iter)
    # phase -> (pre-iteration signal, post-iteration signal)
    signals = {TRAIN: train_signals,
               VALID: valid_signals,
               INFER: infer_signals}
    for number in count_generator:
        message = IterationMessage()
        message.current_iter = number
        message.phase = phase
        message.pre_iter, message.post_iter = signals[phase]
        yield message
Ejemplo n.º 8
0
def iter_generator(count_generator, phase):
    """ Yield one IterationMessage per iteration number, each tagged
    with the given phase and that phase's pre/post iteration signals.
    count_generator is an iterable object yielding iteration numbers
    phase is one of TRAIN, VALID or INFER
    """
    train_signals = (ApplicationDriver.pre_train_iter,
                     ApplicationDriver.post_train_iter)
    valid_signals = (ApplicationDriver.pre_validation_iter,
                     ApplicationDriver.post_validation_iter)
    infer_signals = (ApplicationDriver.pre_infer_iter,
                     ApplicationDriver.post_infer_iter)
    # phase -> (pre-iteration signal, post-iteration signal)
    signals = {
        TRAIN: train_signals,
        VALID: valid_signals,
        INFER: infer_signals,
    }
    for number in count_generator:
        message = IterationMessage()
        message.current_iter = number
        message.phase = phase
        message.pre_iter, message.post_iter = signals[phase]
        yield message
Ejemplo n.º 9
0
    def test_set_fields(self):
        """Check the side effects of assigning ``IterationMessage`` fields."""
        message = IterationMessage()

        # after assigning the iteration number, the tic is positive
        # and no iteration output is stored
        message.current_iter = 3
        self.assertGreater(message._current_iter_tic, 0.0)
        self.assertEqual(message._current_iter_output, None)

        # after assigning the output, the iteration number is kept
        # and a positive duration is available
        console_payload = {CONSOLE: {'test': 'test'}}
        message.current_iter_output = console_payload
        self.assertEqual(message.current_iter, 3)
        self.assertGreater(message.iter_duration, 0.0)
        console_text = message.to_console_string()
        self.assertRegexpMatches(console_text, '.*test=test.*')

        # a non-numeric iteration number is rejected
        with self.assertRaisesRegexp(ValueError, ''):
            message.current_iter = 'test'
Ejemplo n.º 10
0
 def test_interfaces(self):
     """Check the public attributes of a fresh ``IterationMessage``."""
     message = IterationMessage()
     message.current_iter = 0

     # (attribute name, expected value) pairs, verified in this order
     expected_values = (
         ('current_iter', 0),
         ('ops_to_run', {}),
         ('data_feed_dict', {}),
         ('current_iter_output', None),
         ('should_stop', None),
         ('phase', TRAIN),
         ('is_training', True),
         ('is_validation', False),
         ('is_inference', False),
     )
     for attribute, expected in expected_values:
         self.assertEqual(getattr(message, attribute), expected)

     # storing an output makes it readable back and yields a duration
     message.current_iter_output = {'test'}
     self.assertEqual(message.current_iter_output, {'test'})
     self.assertGreater(message.iter_duration, 0.0)
     console_text = message.to_console_string()
     self.assertStartsWith(console_text, 'training')
     self.assertEqual(message.to_tf_summary(0), None)
Ejemplo n.º 11
0
    def run_application(self):
        """
        Open a TF session on ``self.graph``, start the sampler threads,
        build the graph, initialise or restore the variables, then run
        either the training loop or the inference loop.

        The training loop terminates when ``self.final_iter`` reached.
        The inference loop terminates when there is no more
        image sample to be processed from image reader.

        Whatever the outcome, the ``finally`` clause notifies
        post-training listeners (training) or warns about an incomplete
        run (inference), then stops the application.

        :return:
        """
        session_config = ApplicationDriver._tf_config()
        with tf.Session(config=session_config, graph=self.graph) as session:

            # start samplers' threads
            self._run_sampler_threads(session=session)
            self.graph = self._create_graph(self.graph)

            # check app variables initialised and ready for starts
            self.app.check_initialisations()

            # initialise network trainable parameters
            self._rand_init_or_restore_vars(session)

            start_time = time.time()
            loop_status = {}
            try:
                # iteratively run the graph
                if not self.is_training:
                    loop_status['all_saved_flag'] = False
                    self._inference_loop(session, loop_status)
                else:
                    self.model_saver = ModelSaver(
                        session,
                        self.saver,
                        self.save_every_n,
                        self.session_prefix)
                    loop_status['current_iter'] = self.initial_iter
                    self._training_loop(session, loop_status)

            except KeyboardInterrupt:
                tf.logging.warning('User cancelled application')
            except tf.errors.OutOfRangeError:
                # only an inference run carries the flag; reaching the
                # end of the inference Dataset counts as completion
                if loop_status.get('all_saved_flag') is not None:
                    loop_status['all_saved_flag'] = True
            except RuntimeError:
                import sys
                import traceback
                # print the active exception's traceback to stdout
                traceback.print_exc(file=sys.stdout)
            finally:
                tf.logging.info('Cleaning up...')
                if self.is_training:
                    # notify post-training listeners with the last
                    # recorded iteration number
                    final_msg = IterationMessage()
                    final_msg.current_iter = loop_status.get(
                        'current_iter', -1)
                    self.post_training.send(final_msg)
                elif not loop_status.get('all_saved_flag', None):
                    tf.logging.warning('stopped early, incomplete loops')

                tf.logging.info('stopping sampling threads')
                self.app.stop()
                app_name = type(self.app).__name__
                tf.logging.info(
                    "%s stopped (time in second %.2f).",
                    app_name, (time.time() - start_time))
Ejemplo n.º 12
0
    def test_run_vars(self):
        """
        Run training and validation iterations through
        ``app_driver.run_vars`` and check, via one convolution weight
        tensor, that training iterations change the model while a
        validation iteration leaves it unchanged.
        """
        app_driver = get_initialised_driver()
        test_graph = app_driver._create_graph(app_driver.graph)
        test_tensor = app_driver.graph.get_tensor_by_name(
            "G/conv_bn_selu/conv_/w:0")

        with self.test_session(graph=test_graph) as sess:
            app_driver._run_sampler_threads()
            app_driver._run_sampler_threads(sess)
            sess.run(app_driver._init_op)

            iter_msg = IterationMessage()

            # run 1st training iter
            iter_msg.current_iter, iter_msg.phase = 1, TRAIN
            app_driver.run_vars(sess, iter_msg)
            model_value_1 = sess.run(test_tensor)
            self.assertGreater(iter_msg.iter_duration, 0.0)
            print(iter_msg.to_console_string())
            self.assertRegexpMatches(iter_msg.to_console_string(), 'Training')

            # run 2nd training iter
            iter_msg.current_iter, iter_msg.phase = 2, TRAIN
            app_driver.run_vars(sess, iter_msg)
            model_value_2 = sess.run(test_tensor)
            # make sure model gets updated
            self.assertNotAlmostEqual(
                np.mean(np.abs(model_value_1 - model_value_2)), 0.0)
            print(iter_msg.to_console_string())
            self.assertRegexpMatches(iter_msg.to_console_string(), 'Training')

            # run validation iter
            # (assign `current_iter`; the previous `current_tier` spelling
            # only created an unrelated attribute on the message)
            iter_msg.current_iter, iter_msg.phase = 3, VALID
            app_driver.run_vars(sess, iter_msg)
            model_value_3 = sess.run(test_tensor)
            # make sure model does not get updated
            self.assertAlmostEqual(
                np.mean(np.abs(model_value_2 - model_value_3)), 0.0)
            print(iter_msg.to_console_string())
            self.assertRegexpMatches(iter_msg.to_console_string(), 'Validation')

            # run training iter
            iter_msg.current_iter, iter_msg.phase = 4, TRAIN
            app_driver.run_vars(sess, iter_msg)
            model_value_4 = sess.run(test_tensor)
            # make sure model gets updated
            self.assertNotAlmostEqual(
                np.mean(np.abs(model_value_2 - model_value_4)), 0.0)
            self.assertNotAlmostEqual(
                np.mean(np.abs(model_value_3 - model_value_4)), 0.0)
            print(iter_msg.to_console_string())
            self.assertRegexpMatches(iter_msg.to_console_string(), 'Training')

            app_driver.app.stop()
            self.assertEqual(iter_msg.ops_to_run, {})
Ejemplo n.º 13
0
    def test_run_vars(self):
        """
        Run training and validation iterations through
        ``app_driver.run_vars`` and check, via one convolution weight
        tensor, that training iterations change the model while a
        validation iteration leaves it unchanged.
        """
        app_driver = get_initialised_driver()
        test_graph = app_driver._create_graph(app_driver.graph)
        test_tensor = app_driver.graph.get_tensor_by_name(
            "G/conv_bn_selu/conv_/w:0")

        with self.test_session(graph=test_graph) as sess:
            app_driver._run_sampler_threads()
            app_driver._run_sampler_threads(sess)
            sess.run(app_driver._init_op)

            iter_msg = IterationMessage()

            # run 1st training iter
            iter_msg.current_iter, iter_msg.phase = 1, TRAIN
            app_driver.run_vars(sess, iter_msg)
            model_value_1 = sess.run(test_tensor)
            self.assertGreater(iter_msg.iter_duration, 0.0)
            print(iter_msg.to_console_string())
            self.assertRegexpMatches(iter_msg.to_console_string(), 'Training')

            # run 2nd training iter
            iter_msg.current_iter, iter_msg.phase = 2, TRAIN
            app_driver.run_vars(sess, iter_msg)
            model_value_2 = sess.run(test_tensor)
            # make sure model gets updated
            self.assertNotAlmostEqual(
                np.mean(np.abs(model_value_1 - model_value_2)), 0.0)
            print(iter_msg.to_console_string())
            self.assertRegexpMatches(iter_msg.to_console_string(), 'Training')

            # run validation iter
            # (assign `current_iter`; the previous `current_tier` spelling
            # only created an unrelated attribute on the message)
            iter_msg.current_iter, iter_msg.phase = 3, VALID
            app_driver.run_vars(sess, iter_msg)
            model_value_3 = sess.run(test_tensor)
            # make sure model does not get updated
            self.assertAlmostEqual(
                np.mean(np.abs(model_value_2 - model_value_3)), 0.0)
            print(iter_msg.to_console_string())
            self.assertRegexpMatches(iter_msg.to_console_string(),
                                     'Validation')

            # run training iter
            iter_msg.current_iter, iter_msg.phase = 4, TRAIN
            app_driver.run_vars(sess, iter_msg)
            model_value_4 = sess.run(test_tensor)
            # make sure model gets updated
            self.assertNotAlmostEqual(
                np.mean(np.abs(model_value_2 - model_value_4)), 0.0)
            self.assertNotAlmostEqual(
                np.mean(np.abs(model_value_3 - model_value_4)), 0.0)
            print(iter_msg.to_console_string())
            self.assertRegexpMatches(iter_msg.to_console_string(), 'Training')

            app_driver.app.stop()
            self.assertEqual(iter_msg.ops_to_run, {})
Ejemplo n.º 14
0
    def run_application(self):
        """
        Open a TF session on ``self.graph``, start the sampler threads,
        build the graph, initialise or restore the variables, then run
        either the training loop or the inference loop.

        The training loop terminates when ``self.final_iter`` reached.
        The inference loop terminates when there is no more
        image sample to be processed from image reader.

        Whatever the outcome, the ``finally`` clause notifies
        post-training listeners (training) or warns about an incomplete
        run (inference), then stops the application.

        :return:
        """
        session_config = ApplicationDriver._tf_config()
        with tf.Session(config=session_config, graph=self.graph) as session:

            # start samplers' threads
            self._run_sampler_threads(session=session)
            self.graph = self._create_graph(self.graph)

            # check app variables initialised and ready for starts
            self.app.check_initialisations()

            # initialise network trainable parameters
            self._rand_init_or_restore_vars(session)

            start_time = time.time()
            loop_status = {}
            try:
                # iteratively run the graph
                if not self.is_training:
                    loop_status['all_saved_flag'] = False
                    self._inference_loop(session, loop_status)
                else:
                    self.model_saver = ModelSaver(
                        session,
                        self.saver,
                        self.save_every_n,
                        self.session_prefix)
                    loop_status['current_iter'] = self.initial_iter
                    self._training_loop(session, loop_status)

            except KeyboardInterrupt:
                tf.logging.warning('User cancelled application')
            except tf.errors.OutOfRangeError:
                # only an inference run carries the flag; reaching the
                # end of the inference Dataset counts as completion
                if loop_status.get('all_saved_flag') is not None:
                    loop_status['all_saved_flag'] = True
            except RuntimeError:
                import sys
                import traceback
                # print the active exception's traceback to stdout
                traceback.print_exc(file=sys.stdout)
            finally:
                tf.logging.info('Cleaning up...')
                if self.is_training:
                    # notify post-training listeners with the last
                    # recorded iteration number
                    final_msg = IterationMessage()
                    final_msg.current_iter = loop_status.get(
                        'current_iter', -1)
                    self.post_training.send(final_msg)
                elif not loop_status.get('all_saved_flag', None):
                    tf.logging.warning('stopped early, incomplete loops')

                tf.logging.info('stopping sampling threads')
                self.app.stop()
                app_name = type(self.app).__name__
                tf.logging.info(
                    "%s stopped (time in second %.2f).",
                    app_name, (time.time() - start_time))