Exemplo n.º 1
0
    def test_run_vars(self):
        """Check ``run_vars`` over train, train, validation, train iterations.

        Training iterations must change the monitored convolution weight
        tensor, the validation iteration must leave it unchanged, and each
        iteration's console string must name its phase.
        """
        app_driver = get_initialised_driver()
        test_graph = app_driver._create_graph(app_driver.graph)
        test_tensor = app_driver.graph.get_tensor_by_name(
            "G/conv_bn_selu/conv_/w:0")

        with self.test_session(graph=test_graph) as sess:
            # NOTE(review): called once without and once with the session;
            # kept as in the original test -- confirm both calls are needed.
            app_driver._run_sampler_threads()
            app_driver._run_sampler_threads(sess)
            sess.run(app_driver._init_op)

            iter_msg = IterationMessage()

            def check_console(pattern):
                """Print the iteration summary and assert it matches pattern."""
                console_string = iter_msg.to_console_string()
                print(console_string)
                self.assertRegexpMatches(console_string, pattern)

            # Always stop the application so that a failing assertion does
            # not leave sampler threads running.
            try:
                # run 1st training iter
                iter_msg.current_iter, iter_msg.phase = 1, TRAIN
                app_driver.run_vars(sess, iter_msg)
                model_value_1 = sess.run(test_tensor)
                self.assertGreater(iter_msg.iter_duration, 0.0)
                check_console('Training')

                # run 2nd training iter
                iter_msg.current_iter, iter_msg.phase = 2, TRAIN
                app_driver.run_vars(sess, iter_msg)
                model_value_2 = sess.run(test_tensor)
                # make sure the model gets updated
                self.assertNotAlmostEqual(
                    np.mean(np.abs(model_value_1 - model_value_2)), 0.0)
                check_console('Training')

                # run validation iter (the attribute is ``current_iter``;
                # it was previously misspelled ``current_tier``)
                iter_msg.current_iter, iter_msg.phase = 3, VALID
                app_driver.run_vars(sess, iter_msg)
                model_value_3 = sess.run(test_tensor)
                # make sure the model is not updated by validation
                self.assertAlmostEqual(
                    np.mean(np.abs(model_value_2 - model_value_3)), 0.0)
                check_console('Validation')

                # run training iter
                iter_msg.current_iter, iter_msg.phase = 4, TRAIN
                app_driver.run_vars(sess, iter_msg)
                model_value_4 = sess.run(test_tensor)
                # make sure the model gets updated
                self.assertNotAlmostEqual(
                    np.mean(np.abs(model_value_2 - model_value_4)), 0.0)
                self.assertNotAlmostEqual(
                    np.mean(np.abs(model_value_3 - model_value_4)), 0.0)
                check_console('Training')
            finally:
                app_driver.app.stop()
            self.assertEqual(iter_msg.ops_to_run, {})
Exemplo n.º 2
0
    def test_run_vars(self):
        """Check ``run_vars`` over train, train, validation, train iterations.

        Training iterations must change the monitored convolution weight
        tensor, the validation iteration must leave it unchanged, and each
        iteration's console string must name its phase.
        """
        app_driver = get_initialised_driver()
        test_graph = app_driver._create_graph(app_driver.graph)
        test_tensor = app_driver.graph.get_tensor_by_name(
            "G/conv_bn_selu/conv_/w:0")

        with self.test_session(graph=test_graph) as sess:
            # NOTE(review): called once without and once with the session;
            # kept as in the original test -- confirm both calls are needed.
            app_driver._run_sampler_threads()
            app_driver._run_sampler_threads(sess)
            sess.run(app_driver._init_op)

            iter_msg = IterationMessage()

            def check_console(pattern):
                """Print the iteration summary and assert it matches pattern."""
                console_string = iter_msg.to_console_string()
                print(console_string)
                self.assertRegexpMatches(console_string, pattern)

            # Always stop the application so that a failing assertion does
            # not leave sampler threads running.
            try:
                # run 1st training iter
                iter_msg.current_iter, iter_msg.phase = 1, TRAIN
                app_driver.run_vars(sess, iter_msg)
                model_value_1 = sess.run(test_tensor)
                self.assertGreater(iter_msg.iter_duration, 0.0)
                check_console('Training')

                # run 2nd training iter
                iter_msg.current_iter, iter_msg.phase = 2, TRAIN
                app_driver.run_vars(sess, iter_msg)
                model_value_2 = sess.run(test_tensor)
                # make sure the model gets updated
                self.assertNotAlmostEqual(
                    np.mean(np.abs(model_value_1 - model_value_2)), 0.0)
                check_console('Training')

                # run validation iter (the attribute is ``current_iter``;
                # it was previously misspelled ``current_tier``)
                iter_msg.current_iter, iter_msg.phase = 3, VALID
                app_driver.run_vars(sess, iter_msg)
                model_value_3 = sess.run(test_tensor)
                # make sure the model is not updated by validation
                self.assertAlmostEqual(
                    np.mean(np.abs(model_value_2 - model_value_3)), 0.0)
                check_console('Validation')

                # run training iter
                iter_msg.current_iter, iter_msg.phase = 4, TRAIN
                app_driver.run_vars(sess, iter_msg)
                model_value_4 = sess.run(test_tensor)
                # make sure the model gets updated
                self.assertNotAlmostEqual(
                    np.mean(np.abs(model_value_2 - model_value_4)), 0.0)
                self.assertNotAlmostEqual(
                    np.mean(np.abs(model_value_3 - model_value_4)), 0.0)
                check_console('Training')
            finally:
                app_driver.app.stop()
            self.assertEqual(iter_msg.ops_to_run, {})
Exemplo n.º 3
0
    def test_run_vars(self):
        """Drive ``_training_loop`` and check the per-iteration callbacks.

        The loop is configured with initial_iter=0, final_iter=3,
        validation_every_n=2 and validation_max_iter=1. The callbacks must
        report training, training, validation, training in that order.
        Training must change the monitored weight tensor, and validation
        must leave it unchanged.
        """
        app_driver = get_initialised_driver()
        test_graph = app_driver._create_graph(app_driver.graph)
        test_tensor = app_driver.graph.get_tensor_by_name(
            "G/conv_bn_selu/conv_/w:0")

        # Iteration messages in the order the post-iteration signals fire.
        iter_msgs = []

        def get_iter_msgs(iter_msg):
            """Capture iter_msg for testing."""
            iter_msgs.append(iter_msg)

        app_driver.post_train_iter.connect(get_iter_msgs)
        app_driver.post_validation_iter.connect(get_iter_msgs)

        app_driver.initial_iter = 0
        app_driver.final_iter = 3
        app_driver.validation_every_n = 2
        app_driver.validation_max_iter = 1
        loop_status = {}

        with self.test_session(graph=test_graph) as sess:
            app_driver._run_sampler_threads()
            app_driver._run_sampler_threads(sess)
            sess.run(app_driver._init_op)

            # Monitored tensor values, one per post-iteration signal.
            test_vals = []

            def get_test_value(iter_msg):
                """Record the monitored tensor value after an iteration."""
                test_vals.append(sess.run(test_tensor))

            app_driver.post_train_iter.connect(get_test_value)
            app_driver.post_validation_iter.connect(get_test_value)

            # Always stop the application so that a failing assertion does
            # not leave sampler threads running.
            try:
                app_driver._training_loop(sess, loop_status)

                # Fail with a clear message instead of an IndexError below.
                self.assertGreaterEqual(len(iter_msgs), 4)
                self.assertGreaterEqual(len(test_vals), 4)

                # Check sequence of iterations
                for iter_msg, pattern in zip(
                        iter_msgs,
                        ['Training', 'Training', 'Validation', 'Training']):
                    self.assertRegexpMatches(iter_msg.to_console_string(),
                                             pattern)

                # Check durations
                for iter_msg in iter_msgs:
                    self.assertGreater(iter_msg.iter_duration, 0.0)

                # Check training changes test tensor
                self.assertNotAlmostEqual(
                    np.mean(np.abs(test_vals[0] - test_vals[1])), 0.0)
                self.assertNotAlmostEqual(
                    np.mean(np.abs(test_vals[2] - test_vals[3])), 0.0)

                # Check validation doesn't change test tensor
                self.assertAlmostEqual(
                    np.mean(np.abs(test_vals[1] - test_vals[2])), 0.0)
            finally:
                app_driver.app.stop()
Exemplo n.º 4
0
    def test_run_vars(self):
        """Run the event-driven loop and check each ITER_FINISHED message.

        The iteration generator is configured with initial_iter=0,
        final_iter=3, validation_every_n=2 and validation_max_iter=1. The
        messages must report training, training, validation, training in
        that order. Training must change the monitored weight tensor, and
        validation must leave it unchanged.
        """
        app_driver = get_initialised_driver()
        test_graph = app_driver.create_graph(app_driver.app, 1, True)
        test_tensor = test_graph.get_tensor_by_name("G/conv_bn_selu/conv_/w:0")
        train_eval_msgs = []
        test_vals = []

        def get_iter_msgs(_sender, **msg):
            """Capture iter_msg and model values for testing."""
            train_eval_msgs.append(msg['iter_msg'])
            # ``sess`` is bound by the ``with`` statement below; the handler
            # only fires while that session is open.
            test_vals.append(sess.run(test_tensor))
            print(msg['iter_msg'].to_console_string())

        ITER_FINISHED.connect(get_iter_msgs)
        # ITER_FINISHED is module-level: disconnect even if the test fails,
        # otherwise this handler (bound to a closed session) leaks into
        # later tests.
        self.addCleanup(ITER_FINISHED.disconnect, get_iter_msgs)

        with self.test_session(graph=test_graph) as sess:
            GRAPH_CREATED.send(app_driver.app, iter_msg=None)
            SESS_STARTED.send(app_driver.app, iter_msg=None)
            # Always stop the application so that a failing assertion does
            # not leave background threads running.
            try:
                iterations = IterationMessageGenerator(initial_iter=0,
                                                       final_iter=3,
                                                       validation_every_n=2,
                                                       validation_max_iter=1,
                                                       is_training_action=True)
                app_driver.loop(app_driver.app, iterations())

                # Fail with a clear message instead of an IndexError below.
                self.assertGreaterEqual(len(train_eval_msgs), 4)
                self.assertGreaterEqual(len(test_vals), 4)

                # Check sequence of iterations
                for iter_msg, pattern in zip(
                        train_eval_msgs,
                        ['training', 'training', 'validation', 'training']):
                    self.assertRegexpMatches(iter_msg.to_console_string(),
                                             pattern)

                # Check durations
                for iter_msg in train_eval_msgs:
                    self.assertGreater(iter_msg.iter_duration, 0.0)

                # Check training changes test tensor
                self.assertNotAlmostEqual(
                    np.mean(np.abs(test_vals[0] - test_vals[1])), 0.0)
                self.assertNotAlmostEqual(
                    np.mean(np.abs(test_vals[2] - test_vals[3])), 0.0)

                # Check validation doesn't change test tensor
                self.assertAlmostEqual(
                    np.mean(np.abs(test_vals[1] - test_vals[2])), 0.0)
            finally:
                app_driver.app.stop()
Exemplo n.º 5
0
    def test_init(self):
        """Run a single iteration through the event-driven loop.

        The ModelRestorer, ConsoleLogger and SamplerThreading handlers are
        loaded, and ``self.iteration_listener`` is attached to ITER_FINISHED
        for the duration of the test.
        """
        ITER_FINISHED.connect(self.iteration_listener)
        # ITER_FINISHED is module-level: always disconnect, even on failure,
        # so the listener does not leak into other tests.
        self.addCleanup(ITER_FINISHED.disconnect, self.iteration_listener)

        app_driver = get_initialised_driver()
        app_driver.load_event_handlers([
            'niftynet.engine.handler_model.ModelRestorer',
            'niftynet.engine.handler_console.ConsoleLogger',
            'niftynet.engine.handler_sampler.SamplerThreading'
        ])
        graph = app_driver.create_graph(app_driver.app, 1, True)
        # Always stop the application so that a failure inside the loop does
        # not leave sampler threads running.
        try:
            with self.cached_session(graph=graph):
                SESS_STARTED.send(app_driver.app, iter_msg=None)
                msg = IterationMessage()
                msg.current_iter = 1
                app_driver.loop(app_driver.app, [msg])
        finally:
            app_driver.app.stop()

    def test_init(self):
        """Run the event-driven loop for a non-training action.

        The generator is built with ``is_training_action=False``. The
        application's iteration-update and output-interpretation hooks are
        overridden, and the ModelRestorer, OutputInterpreter and
        SamplerThreading handlers are loaded.
        """
        app_driver = get_initialised_driver()
        test_graph = app_driver.create_graph(app_driver.app, 1, True)

        # Override application hooks; ``set_iteration_update`` is defined
        # elsewhere in this test module (presumably a test stub -- confirm).
        app_driver.app.set_iteration_update = set_iteration_update
        app_driver.app.interpret_output = self.create_interpreter()

        app_driver.load_event_handlers([
            'niftynet.engine.handler_model.ModelRestorer',
            'niftynet.engine.handler_network_output.OutputInterpreter',
            'niftynet.engine.handler_sampler.SamplerThreading'
        ])
        # Always stop the application so that a failure inside the loop does
        # not leave sampler threads running.
        try:
            with self.test_session(graph=test_graph):
                SESS_STARTED.send(app_driver.app, iter_msg=None)

                iterator = IterationMessageGenerator(is_training_action=False)
                app_driver.loop(app_driver.app, iterator())
        finally:
            app_driver.app.stop()
Exemplo n.º 7
0
    def test_run_vars(self):
        """Drive ``_training_loop`` and check the per-iteration callbacks.

        The loop is configured with initial_iter=0, final_iter=3,
        validation_every_n=2 and validation_max_iter=1. The callbacks must
        report training, training, validation, training in that order.
        Training must change the monitored weight tensor, and validation
        must leave it unchanged.
        """
        app_driver = get_initialised_driver()
        test_graph = app_driver._create_graph(app_driver.graph)
        test_tensor = app_driver.graph.get_tensor_by_name(
            "G/conv_bn_selu/conv_/w:0")

        # Iteration messages in the order the post-iteration signals fire.
        iter_msgs = []

        def get_iter_msgs(iter_msg):
            """Capture iter_msg for testing."""
            iter_msgs.append(iter_msg)

        app_driver.post_train_iter.connect(get_iter_msgs)
        app_driver.post_validation_iter.connect(get_iter_msgs)

        app_driver.initial_iter = 0
        app_driver.final_iter = 3
        app_driver.validation_every_n = 2
        app_driver.validation_max_iter = 1
        loop_status = {}

        with self.test_session(graph=test_graph) as sess:
            app_driver._run_sampler_threads()
            app_driver._run_sampler_threads(sess)
            sess.run(app_driver._init_op)

            # Monitored tensor values, one per post-iteration signal.
            test_vals = []

            def get_test_value(iter_msg):
                """Record the monitored tensor value after an iteration."""
                test_vals.append(sess.run(test_tensor))

            app_driver.post_train_iter.connect(get_test_value)
            app_driver.post_validation_iter.connect(get_test_value)

            # Always stop the application so that a failing assertion does
            # not leave sampler threads running.
            try:
                app_driver._training_loop(sess, loop_status)

                # Fail with a clear message instead of an IndexError below.
                self.assertGreaterEqual(len(iter_msgs), 4)
                self.assertGreaterEqual(len(test_vals), 4)

                # Check sequence of iterations
                for iter_msg, pattern in zip(
                        iter_msgs,
                        ['Training', 'Training', 'Validation', 'Training']):
                    self.assertRegexpMatches(iter_msg.to_console_string(),
                                             pattern)

                # Check durations
                for iter_msg in iter_msgs:
                    self.assertGreater(iter_msg.iter_duration, 0.0)

                # Check training changes test tensor
                self.assertNotAlmostEqual(
                    np.mean(np.abs(test_vals[0] - test_vals[1])), 0.0)
                self.assertNotAlmostEqual(
                    np.mean(np.abs(test_vals[2] - test_vals[3])), 0.0)

                # Check validation doesn't change test tensor
                self.assertAlmostEqual(
                    np.mean(np.abs(test_vals[1] - test_vals[2])), 0.0)
            finally:
                app_driver.app.stop()