def test_run_vars(self):
    """Check that training iterations change the model and validation does not.

    Drives ``app_driver.run_vars`` through four iterations (train 1,
    train 2, validation 3, train 4) and reads the
    ``G/conv_bn_selu/conv_/w:0`` tensor after each one. The console string
    of every iteration must name its phase.
    """
    app_driver = get_initialised_driver()
    test_graph = app_driver._create_graph(app_driver.graph)
    test_tensor = app_driver.graph.get_tensor_by_name(
        "G/conv_bn_selu/conv_/w:0")

    with self.test_session(graph=test_graph) as sess:
        app_driver._run_sampler_threads()
        app_driver._run_sampler_threads(sess)
        iter_msg = IterationMessage()

        def run_iteration(current_iter, phase, expected_label):
            """Run one iteration, check its console label, return the tensor value."""
            iter_msg.current_iter, iter_msg.phase = current_iter, phase
            app_driver.run_vars(sess, iter_msg)
            model_value = sess.run(test_tensor)
            print(iter_msg.to_console_string())
            self.assertRegexpMatches(
                iter_msg.to_console_string(), expected_label)
            return model_value

        try:
            sess.run(app_driver._init_op)

            # 1st training iteration
            model_value_1 = run_iteration(1, TRAIN, 'Training')
            self.assertGreater(iter_msg.iter_duration, 0.0)

            # 2nd training iteration: the model must get updated
            model_value_2 = run_iteration(2, TRAIN, 'Training')
            self.assertNotAlmostEqual(
                np.mean(np.abs(model_value_1 - model_value_2)), 0.0)

            # validation iteration: the model must NOT get updated
            model_value_3 = run_iteration(3, VALID, 'Validation')
            self.assertAlmostEqual(
                np.mean(np.abs(model_value_2 - model_value_3)), 0.0)

            # training again: the model must get updated
            model_value_4 = run_iteration(4, TRAIN, 'Training')
            self.assertNotAlmostEqual(
                np.mean(np.abs(model_value_2 - model_value_4)), 0.0)
            self.assertNotAlmostEqual(
                np.mean(np.abs(model_value_3 - model_value_4)), 0.0)
        finally:
            # Stop the sampler threads even when an assertion above fails.
            app_driver.app.stop()
        self.assertEqual(iter_msg.ops_to_run, {})
def test_run_vars(self):
    """Check that training iterations change the model and validation does not.

    Drives ``app_driver.run_vars`` through four iterations (train 1,
    train 2, validation 3, train 4) and reads the
    ``G/conv_bn_selu/conv_/w:0`` tensor after each one. The console string
    of every iteration must name its phase.
    """
    app_driver = get_initialised_driver()
    test_graph = app_driver._create_graph(app_driver.graph)
    test_tensor = app_driver.graph.get_tensor_by_name(
        "G/conv_bn_selu/conv_/w:0")

    with self.test_session(graph=test_graph) as sess:
        app_driver._run_sampler_threads()
        app_driver._run_sampler_threads(sess)
        iter_msg = IterationMessage()

        def run_iteration(current_iter, phase, expected_label):
            """Run one iteration, check its console label, return the tensor value."""
            iter_msg.current_iter, iter_msg.phase = current_iter, phase
            app_driver.run_vars(sess, iter_msg)
            model_value = sess.run(test_tensor)
            print(iter_msg.to_console_string())
            self.assertRegexpMatches(
                iter_msg.to_console_string(), expected_label)
            return model_value

        try:
            sess.run(app_driver._init_op)

            # 1st training iteration
            model_value_1 = run_iteration(1, TRAIN, 'Training')
            self.assertGreater(iter_msg.iter_duration, 0.0)

            # 2nd training iteration: the model must get updated
            model_value_2 = run_iteration(2, TRAIN, 'Training')
            self.assertNotAlmostEqual(
                np.mean(np.abs(model_value_1 - model_value_2)), 0.0)

            # validation iteration: the model must NOT get updated
            model_value_3 = run_iteration(3, VALID, 'Validation')
            self.assertAlmostEqual(
                np.mean(np.abs(model_value_2 - model_value_3)), 0.0)

            # training again: the model must get updated
            model_value_4 = run_iteration(4, TRAIN, 'Training')
            self.assertNotAlmostEqual(
                np.mean(np.abs(model_value_2 - model_value_4)), 0.0)
            self.assertNotAlmostEqual(
                np.mean(np.abs(model_value_3 - model_value_4)), 0.0)
        finally:
            # Stop the sampler threads even when an assertion above fails.
            app_driver.app.stop()
        self.assertEqual(iter_msg.ops_to_run, {})
def test_run_vars(self):
    """Check the training loop's iteration sequence and its effect on the model.

    With initial_iter=0, final_iter=3, validation_every_n=2 and
    validation_max_iter=1 the loop is expected to emit Training, Training,
    Validation, Training iterations. Training must change the
    ``G/conv_bn_selu/conv_/w:0`` tensor and validation must not.
    """
    app_driver = get_initialised_driver()
    test_graph = app_driver._create_graph(app_driver.graph)
    test_tensor = app_driver.graph.get_tensor_by_name(
        "G/conv_bn_selu/conv_/w:0")

    # Plain lists are enough: the callbacks below only append to them.
    iter_msgs = []
    test_vals = []

    def get_iter_msgs(iter_msg):
        """Capture iter_msg for testing."""
        iter_msgs.append(iter_msg)

    app_driver.post_train_iter.connect(get_iter_msgs)
    app_driver.post_validation_iter.connect(get_iter_msgs)
    app_driver.initial_iter = 0
    app_driver.final_iter = 3
    app_driver.validation_every_n = 2
    app_driver.validation_max_iter = 1
    loop_status = {}

    with self.test_session(graph=test_graph) as sess:
        app_driver._run_sampler_threads()
        app_driver._run_sampler_threads(sess)
        try:
            sess.run(app_driver._init_op)

            def get_test_value(iter_msg):
                """Record the tracked tensor's value after each iteration."""
                test_vals.append(sess.run(test_tensor))

            app_driver.post_train_iter.connect(get_test_value)
            app_driver.post_validation_iter.connect(get_test_value)
            app_driver._training_loop(sess, loop_status)

            # Fail with a clear message rather than an IndexError below.
            self.assertGreaterEqual(len(iter_msgs), 4)
            self.assertGreaterEqual(len(test_vals), 4)

            # Check sequence of iterations
            self.assertRegexpMatches(iter_msgs[0].to_console_string(),
                                     'Training')
            self.assertRegexpMatches(iter_msgs[1].to_console_string(),
                                     'Training')
            self.assertRegexpMatches(iter_msgs[2].to_console_string(),
                                     'Validation')
            self.assertRegexpMatches(iter_msgs[3].to_console_string(),
                                     'Training')

            # Check durations
            for iter_msg in iter_msgs:
                self.assertGreater(iter_msg.iter_duration, 0.0)

            # Check training changes test tensor
            self.assertNotAlmostEqual(
                np.mean(np.abs(test_vals[0] - test_vals[1])), 0.0)
            self.assertNotAlmostEqual(
                np.mean(np.abs(test_vals[2] - test_vals[3])), 0.0)

            # Check validation doesn't change test tensor
            self.assertAlmostEqual(
                np.mean(np.abs(test_vals[1] - test_vals[2])), 0.0)
        finally:
            # Stop the sampler threads even when an assertion above fails.
            app_driver.app.stop()
def test_run_vars(self):
    """Check the event-driven loop's iteration sequence and model updates.

    An ITER_FINISHED listener records each iteration message and the value
    of ``G/conv_bn_selu/conv_/w:0`` after that iteration. The generator
    (final_iter=3, validation_every_n=2, validation_max_iter=1) is expected
    to yield training, training, validation, training. Training must change
    the tensor and validation must not. The listener is always disconnected
    so it does not stay attached to the shared signal after this test.
    """
    app_driver = get_initialised_driver()
    test_graph = app_driver.create_graph(app_driver.app, 1, True)
    test_tensor = test_graph.get_tensor_by_name("G/conv_bn_selu/conv_/w:0")
    train_eval_msgs = []
    test_vals = []

    def get_iter_msgs(_sender, **msg):
        """Capture iter_msg and model values for testing."""
        train_eval_msgs.append(msg['iter_msg'])
        # `sess` is bound by the `with` block below; this handler only
        # fires while that block is running.
        test_vals.append(sess.run(test_tensor))
        print(msg['iter_msg'].to_console_string())

    ITER_FINISHED.connect(get_iter_msgs)
    try:
        with self.test_session(graph=test_graph) as sess:
            try:
                GRAPH_CREATED.send(app_driver.app, iter_msg=None)
                SESS_STARTED.send(app_driver.app, iter_msg=None)
                iterations = IterationMessageGenerator(
                    initial_iter=0,
                    final_iter=3,
                    validation_every_n=2,
                    validation_max_iter=1,
                    is_training_action=True)
                app_driver.loop(app_driver.app, iterations())

                # Fail with a clear message rather than an IndexError below.
                self.assertGreaterEqual(len(train_eval_msgs), 4)
                self.assertGreaterEqual(len(test_vals), 4)

                # Check sequence of iterations
                self.assertRegexpMatches(
                    train_eval_msgs[0].to_console_string(), 'training')
                self.assertRegexpMatches(
                    train_eval_msgs[1].to_console_string(), 'training')
                self.assertRegexpMatches(
                    train_eval_msgs[2].to_console_string(), 'validation')
                self.assertRegexpMatches(
                    train_eval_msgs[3].to_console_string(), 'training')

                # Check durations
                for iter_msg in train_eval_msgs:
                    self.assertGreater(iter_msg.iter_duration, 0.0)

                # Check training changes test tensor
                self.assertNotAlmostEqual(
                    np.mean(np.abs(test_vals[0] - test_vals[1])), 0.0)
                self.assertNotAlmostEqual(
                    np.mean(np.abs(test_vals[2] - test_vals[3])), 0.0)

                # Check validation doesn't change test tensor
                self.assertAlmostEqual(
                    np.mean(np.abs(test_vals[1] - test_vals[2])), 0.0)
            finally:
                # Stop the app even when an assertion above fails.
                app_driver.app.stop()
    finally:
        # Never leave the test listener attached to the shared signal.
        ITER_FINISHED.disconnect(get_iter_msgs)
def test_init(self):
    """Run one iteration (current_iter=1) through the event-handler loop.

    Loads the model-restorer, console-logger and sampler-threading handlers.
    ``self.iteration_listener`` is connected to ITER_FINISHED for the
    duration of the test and is always disconnected afterwards, even when
    the loop raises.
    """
    ITER_FINISHED.connect(self.iteration_listener)
    try:
        app_driver = get_initialised_driver()
        app_driver.load_event_handlers([
            'niftynet.engine.handler_model.ModelRestorer',
            'niftynet.engine.handler_console.ConsoleLogger',
            'niftynet.engine.handler_sampler.SamplerThreading'
        ])
        graph = app_driver.create_graph(app_driver.app, 1, True)
        with self.cached_session(graph=graph):
            try:
                SESS_STARTED.send(app_driver.app, iter_msg=None)
                msg = IterationMessage()
                msg.current_iter = 1
                app_driver.loop(app_driver.app, [msg])
            finally:
                # Stop the app even when the loop above raises.
                app_driver.app.stop()
    finally:
        ITER_FINISHED.disconnect(self.iteration_listener)
def test_init(self):
    """Run the non-training loop (``is_training_action=False``) end to end.

    Installs ``set_iteration_update`` and the interpreter returned by
    ``self.create_interpreter()`` on the app. It then loads the
    model-restorer, output-interpreter and sampler-threading handlers and
    runs every iteration produced by the generator. ``app.stop()`` is
    guaranteed to run.
    """
    app_driver = get_initialised_driver()
    test_graph = app_driver.create_graph(app_driver.app, 1, True)
    app_driver.app.set_iteration_update = set_iteration_update
    app_driver.app.interpret_output = self.create_interpreter()
    app_driver.load_event_handlers([
        'niftynet.engine.handler_model.ModelRestorer',
        'niftynet.engine.handler_network_output.OutputInterpreter',
        'niftynet.engine.handler_sampler.SamplerThreading'
    ])
    with self.test_session(graph=test_graph):
        try:
            SESS_STARTED.send(app_driver.app, iter_msg=None)
            iterator = IterationMessageGenerator(is_training_action=False)
            app_driver.loop(app_driver.app, iterator())
        finally:
            # Stop the app even when the loop above raises.
            app_driver.app.stop()
def test_run_vars(self):
    """Check the training loop's iteration sequence and its effect on the model.

    With initial_iter=0, final_iter=3, validation_every_n=2 and
    validation_max_iter=1 the loop is expected to emit Training, Training,
    Validation, Training iterations. Training must change the
    ``G/conv_bn_selu/conv_/w:0`` tensor and validation must not.
    """
    app_driver = get_initialised_driver()
    test_graph = app_driver._create_graph(app_driver.graph)
    test_tensor = app_driver.graph.get_tensor_by_name(
        "G/conv_bn_selu/conv_/w:0")

    # Plain lists are enough: the callbacks below only append to them.
    iter_msgs = []
    test_vals = []

    def get_iter_msgs(iter_msg):
        """Capture iter_msg for testing."""
        iter_msgs.append(iter_msg)

    app_driver.post_train_iter.connect(get_iter_msgs)
    app_driver.post_validation_iter.connect(get_iter_msgs)
    app_driver.initial_iter = 0
    app_driver.final_iter = 3
    app_driver.validation_every_n = 2
    app_driver.validation_max_iter = 1
    loop_status = {}

    with self.test_session(graph=test_graph) as sess:
        app_driver._run_sampler_threads()
        app_driver._run_sampler_threads(sess)
        try:
            sess.run(app_driver._init_op)

            def get_test_value(iter_msg):
                """Record the tracked tensor's value after each iteration."""
                test_vals.append(sess.run(test_tensor))

            app_driver.post_train_iter.connect(get_test_value)
            app_driver.post_validation_iter.connect(get_test_value)
            app_driver._training_loop(sess, loop_status)

            # Fail with a clear message rather than an IndexError below.
            self.assertGreaterEqual(len(iter_msgs), 4)
            self.assertGreaterEqual(len(test_vals), 4)

            # Check sequence of iterations
            self.assertRegexpMatches(iter_msgs[0].to_console_string(),
                                     'Training')
            self.assertRegexpMatches(iter_msgs[1].to_console_string(),
                                     'Training')
            self.assertRegexpMatches(iter_msgs[2].to_console_string(),
                                     'Validation')
            self.assertRegexpMatches(iter_msgs[3].to_console_string(),
                                     'Training')

            # Check durations
            for iter_msg in iter_msgs:
                self.assertGreater(iter_msg.iter_duration, 0.0)

            # Check training changes test tensor
            self.assertNotAlmostEqual(
                np.mean(np.abs(test_vals[0] - test_vals[1])), 0.0)
            self.assertNotAlmostEqual(
                np.mean(np.abs(test_vals[2] - test_vals[3])), 0.0)

            # Check validation doesn't change test tensor
            self.assertAlmostEqual(
                np.mean(np.abs(test_vals[1] - test_vals[2])), 0.0)
        finally:
            # Stop the sampler threads even when an assertion above fails.
            app_driver.app.stop()