Esempio n. 1
0
    def loop_step(application, iteration_message):
        """
        Execute a single iteration described by ``iteration_message``.

        The ``ITER_STARTED`` signal is sent first, then the ops listed in
        the message are evaluated with the default TF session, the result
        is stored back on the message, and finally ``ITER_FINISHED`` is
        sent. Both signals carry the message as ``iter_msg``.

        :param application: the sender passed to both signals
        :param iteration_message: an ``engine.IterationMessage`` instance;
            its ``ops_to_run`` and ``data_feed_dict`` are read and its
            ``current_iter_output`` is overwritten
        :return: None
        """
        # announce the iteration before anything is evaluated
        ITER_STARTED.send(application, iter_msg=iteration_message)

        # the ops are run in whichever session is currently the default
        session = tf.get_default_session()
        assert session, 'method should be called within a TF session context.'

        # fetches and feeds come from the message; the raw outputs go
        # back onto it for the ITER_FINISHED receivers to use
        fetches = iteration_message.ops_to_run
        feeds = iteration_message.data_feed_dict
        iteration_message.current_iter_output = session.run(
            fetches, feed_dict=feeds)

        # announce that the iteration outputs are now available
        ITER_FINISHED.send(application, iter_msg=iteration_message)
Esempio n. 2
0
    def test_run_vars(self):
        """
        Run a short training/validation loop and check the iteration
        sequence, the iteration durations and the effect on one variable.

        A listener on ``ITER_FINISHED`` records each iteration message
        together with the current value of a convolution weight tensor,
        so consecutive iterations can be compared.
        """
        app_driver = get_initialised_driver()
        test_graph = app_driver.create_graph(app_driver.app, 1, True)
        test_tensor = test_graph.get_tensor_by_name("G/conv_bn_selu/conv_/w:0")
        train_eval_msgs = []
        test_vals = []

        def get_iter_msgs(_sender, **msg):
            """Captures iter_msg and model values for testing"""
            # ``sess`` is looked up at call time; it is bound by the
            # ``with`` block below, inside which the loop is run.
            train_eval_msgs.append(msg['iter_msg'])
            test_vals.append(sess.run(test_tensor))
            print(msg['iter_msg'].to_console_string())

        ITER_FINISHED.connect(get_iter_msgs)
        # try/finally so that a failing assertion cannot skip the
        # disconnect and leave this listener attached to the signal.
        try:
            with self.test_session(graph=test_graph) as sess:
                # the application is stopped even when a check fails
                try:
                    GRAPH_CREATED.send(app_driver.app, iter_msg=None)
                    SESS_STARTED.send(app_driver.app, iter_msg=None)
                    iterations = IterationMessageGenerator(
                        initial_iter=0,
                        final_iter=3,
                        validation_every_n=2,
                        validation_max_iter=1,
                        is_training_action=True)
                    app_driver.loop(app_driver.app, iterations())

                    # Check sequence of iterations
                    expected_phases = (
                        'training', 'training', 'validation', 'training')
                    for idx, phase in enumerate(expected_phases):
                        self.assertRegexpMatches(
                            train_eval_msgs[idx].to_console_string(), phase)

                    # Check durations
                    for iter_msg in train_eval_msgs:
                        self.assertGreater(iter_msg.iter_duration, 0.0)

                    # Check training changes test tensor
                    self.assertNotAlmostEqual(
                        np.mean(np.abs(test_vals[0] - test_vals[1])), 0.0)
                    self.assertNotAlmostEqual(
                        np.mean(np.abs(test_vals[2] - test_vals[3])), 0.0)

                    # Check validation doesn't change test tensor
                    self.assertAlmostEqual(
                        np.mean(np.abs(test_vals[1] - test_vals[2])), 0.0)
                finally:
                    app_driver.app.stop()
        finally:
            ITER_FINISHED.disconnect(get_iter_msgs)
Esempio n. 3
0
    def __init__(self,
                 model_dir=None,
                 initial_iter=0,
                 tensorboard_every_n=0,
                 **_unused):
        """
        Set up the summary folder and subscribe to the engine signals.

        :param model_dir: folder under which the ``logs`` summary
            folder is located
        :param initial_iter: starting iteration; a new summary subfolder
            is requested only when this is 0 (i.e. not finetuning)
        :param tensorboard_every_n: stored on the instance as
            ``tensorboard_every_n``
        :param _unused: extra keyword arguments, accepted and ignored
        """
        # NOTE(review): ``model_dir`` defaults to None but is handed
        # directly to ``os.path.join`` -- confirm callers always set it.
        self.tensorboard_every_n = tensorboard_every_n

        logs_root = os.path.join(model_dir, 'logs')
        start_from_scratch = initial_iter == 0
        self.summary_dir = get_latest_subfolder(
            logs_root, create_new=start_from_scratch)

        # both writers start unset
        self.writer_train = self.writer_valid = None

        # hook the handler methods onto the engine signals, in order
        subscriptions = (
            (GRAPH_CREATED, self.init_writer),
            (ITER_STARTED, self.read_tensorboard_op),
            (ITER_FINISHED, self.write_tensorboard),
        )
        for signal, callback in subscriptions:
            signal.connect(callback)
Esempio n. 4
0
    def test_init(self):
        """
        Load the three event handlers named below and run a loop of one
        iteration while ``self.iteration_listener`` is connected to
        ``ITER_FINISHED``.
        """
        ITER_FINISHED.connect(self.iteration_listener)
        # try/finally so that the listener is always disconnected, even
        # when the driver set-up or the loop raises.
        try:
            app_driver = get_initialised_driver()
            app_driver.load_event_handlers([
                'niftynet.engine.handler_model.ModelRestorer',
                'niftynet.engine.handler_console.ConsoleLogger',
                'niftynet.engine.handler_sampler.SamplerThreading'
            ])
            graph = app_driver.create_graph(app_driver.app, 1, True)
            # once the graph exists the application is always stopped,
            # whether or not the loop succeeds
            try:
                with self.cached_session(graph=graph):
                    SESS_STARTED.send(app_driver.app, iter_msg=None)
                    msg = IterationMessage()
                    msg.current_iter = 1
                    app_driver.loop(app_driver.app, [msg])
            finally:
                app_driver.app.stop()
        finally:
            ITER_FINISHED.disconnect(self.iteration_listener)
Esempio n. 5
0
    def __init__(self,
                 model_dir,
                 save_every_n=0,
                 max_checkpoints=1,
                 is_training_action=True,
                 **_unused):
        """
        Store the saving options and subscribe to the engine signals.

        :param model_dir: passed to ``make_model_name`` to build
            ``file_name_prefix``
        :param save_every_n: saving frequency; interval saving is only
            enabled when this is positive
        :param max_checkpoints: stored on the instance as
            ``max_checkpoints``
        :param is_training_action: when true, the model is also saved
            when the session finishes
        :param _unused: extra keyword arguments, accepted and ignored
        """
        self.save_every_n = save_every_n
        self.max_checkpoints = max_checkpoints
        self.file_name_prefix = make_model_name(model_dir)
        # the saver itself is created later, by ``init_saver``
        self.saver = None

        # the saver is initialised once the graph has been finalised
        SESS_STARTED.connect(self.init_saver)

        # periodic saving only makes sense for a positive frequency
        interval_saving = self.save_every_n > 0
        if interval_saving:
            ITER_FINISHED.connect(self.save_model_interval)

        # outside of training there is no final model to save
        if not is_training_action:
            return
        SESS_FINISHED.connect(self.save_model)
Esempio n. 6
0
 def __init__(self, **_unused):
     """
     Subscribe this handler to the iteration signals.

     ``self.read_console_vars`` is connected to ``ITER_STARTED`` and
     ``self.print_console_vars`` to ``ITER_FINISHED``. Keyword
     arguments are accepted and ignored.
     """
     subscriptions = (
         (ITER_STARTED, self.read_console_vars),
         (ITER_FINISHED, self.print_console_vars),
     )
     for signal, callback in subscriptions:
         signal.connect(callback)
 def __init__(self, **_unused):
     """
     Connect ``self.check_criteria`` to the ``ITER_FINISHED`` signal.

     Keyword arguments are accepted and ignored.
     """
     on_iter_finished = self.check_criteria
     ITER_FINISHED.connect(on_iter_finished)
Esempio n. 8
0
 def __init__(self, **_unused):
     """
     Connect ``self.update_performance_history`` to the
     ``ITER_FINISHED`` signal.

     Keyword arguments are accepted and ignored.
     """
     on_iter_finished = self.update_performance_history
     ITER_FINISHED.connect(on_iter_finished)