def loop_step(application, iteration_message):
    """
    Run one iteration through the default TF session.

    ``ITER_STARTED`` is sent before the ops are evaluated and
    ``ITER_FINISHED`` afterwards; both use ``application`` as the sender
    and carry the message as ``iter_msg``.

    :param application: sender of the ITER_* events
    :param iteration_message: an ``engine.IterationMessage`` instance;
        its ``ops_to_run`` and ``data_feed_dict`` are handed to
        ``session.run()`` and the result is written back to its
        ``current_iter_output``
    :return: None
    """
    # let the observers know that an iteration is about to start
    ITER_STARTED.send(application, iter_msg=iteration_message)

    # the ops collected in ``iter_msg.ops_to_run`` are evaluated by
    # ``session.run()``; the outputs are stored on the message so that
    # the application (and observers) can interpret them afterwards.
    session = tf.get_default_session()
    assert session, 'method should be called within a TF session context.'

    fetches = iteration_message.ops_to_run
    feeds = iteration_message.data_feed_dict
    iteration_message.current_iter_output = session.run(
        fetches, feed_dict=feeds)

    # let the observers know that the iteration has finished
    ITER_FINISHED.send(application, iter_msg=iteration_message)
def test_run_vars(self):
    """
    Run a short loop of training/validation iterations and check the
    order of the iterations, their durations, and that the monitored
    variable changes across training iterations but not across the
    validation iteration.
    """
    app_driver = get_initialised_driver()
    test_graph = app_driver.create_graph(app_driver.app, 1, True)
    test_tensor = test_graph.get_tensor_by_name(
        "G/conv_bn_selu/conv_/w:0")
    train_eval_msgs = []
    test_vals = []

    def get_iter_msgs(_sender, **msg):
        """Captures iter_msg and model values for testing"""
        train_eval_msgs.append(msg['iter_msg'])
        # ``sess`` is the session opened by the ``with`` statement below;
        # the listener is expected to fire only while that session is open
        test_vals.append(sess.run(test_tensor))
        print(msg['iter_msg'].to_console_string())

    ITER_FINISHED.connect(get_iter_msgs)
    # try/finally: a failing assertion must not leave the listener
    # connected to the signal, where it could fire during other tests
    try:
        with self.test_session(graph=test_graph) as sess:
            GRAPH_CREATED.send(app_driver.app, iter_msg=None)
            SESS_STARTED.send(app_driver.app, iter_msg=None)
            iterations = IterationMessageGenerator(
                initial_iter=0,
                final_iter=3,
                validation_every_n=2,
                validation_max_iter=1,
                is_training_action=True)
            app_driver.loop(app_driver.app, iterations())

            # Check sequence of iterations
            self.assertRegexpMatches(
                train_eval_msgs[0].to_console_string(), 'training')
            self.assertRegexpMatches(
                train_eval_msgs[1].to_console_string(), 'training')
            self.assertRegexpMatches(
                train_eval_msgs[2].to_console_string(), 'validation')
            self.assertRegexpMatches(
                train_eval_msgs[3].to_console_string(), 'training')

            # Check durations
            for iter_msg in train_eval_msgs:
                self.assertGreater(iter_msg.iter_duration, 0.0)

            # Check training changes test tensor
            self.assertNotAlmostEqual(
                np.mean(np.abs(test_vals[0] - test_vals[1])), 0.0)
            self.assertNotAlmostEqual(
                np.mean(np.abs(test_vals[2] - test_vals[3])), 0.0)

            # Check validation doesn't change test tensor
            self.assertAlmostEqual(
                np.mean(np.abs(test_vals[1] - test_vals[2])), 0.0)

        app_driver.app.stop()
    finally:
        ITER_FINISHED.disconnect(get_iter_msgs)
def __init__(self,
             model_dir=None,
             initial_iter=0,
             tensorboard_every_n=0,
             **_unused):
    """
    Set up the summary folder and subscribe the tensorboard callbacks.

    :param model_dir: folder under which the ``logs`` subfolder is looked up
    :param initial_iter: a new summary subfolder is requested only
        when this equals 0
    :param tensorboard_every_n: stored as ``self.tensorboard_every_n``
    :param _unused: extra keyword arguments, accepted and ignored
    """
    # NOTE(review): ``model_dir`` defaults to None, which ``os.path.join``
    # does not accept -- confirm that callers always provide a path.
    self.tensorboard_every_n = tensorboard_every_n

    # creating new summary subfolder if it's not finetuning
    logs_root = os.path.join(model_dir, 'logs')
    start_fresh = initial_iter == 0
    self.summary_dir = get_latest_subfolder(logs_root, create_new=start_fresh)

    # no writer exists yet at construction time
    self.writer_train = None
    self.writer_valid = None

    GRAPH_CREATED.connect(self.init_writer)
    ITER_STARTED.connect(self.read_tensorboard_op)
    ITER_FINISHED.connect(self.write_tensorboard)
def test_init(self):
    """
    Run a single iteration through the driver loop with the model
    restorer, console logger and sampler threading handlers loaded,
    while ``self.iteration_listener`` is subscribed to ``ITER_FINISHED``.
    """
    ITER_FINISHED.connect(self.iteration_listener)
    # try/finally: if anything below raises, the listener must still be
    # removed so that it is not triggered by iterations of other tests
    try:
        app_driver = get_initialised_driver()
        app_driver.load_event_handlers([
            'niftynet.engine.handler_model.ModelRestorer',
            'niftynet.engine.handler_console.ConsoleLogger',
            'niftynet.engine.handler_sampler.SamplerThreading',
        ])
        graph = app_driver.create_graph(app_driver.app, 1, True)
        # the session object itself is not used here; it only needs to
        # be the active session while the loop runs
        with self.cached_session(graph=graph):
            SESS_STARTED.send(app_driver.app, iter_msg=None)
            msg = IterationMessage()
            msg.current_iter = 1
            app_driver.loop(app_driver.app, [msg])
        app_driver.app.stop()
    finally:
        ITER_FINISHED.disconnect(self.iteration_listener)
def __init__(self,
             model_dir,
             save_every_n=0,
             max_checkpoints=1,
             is_training_action=True,
             **_unused):
    """
    Store the saving options and subscribe the saving callbacks.

    :param model_dir: passed to ``make_model_name`` to build
        ``self.file_name_prefix``
    :param save_every_n: ``save_model_interval`` is subscribed to
        ``ITER_FINISHED`` only when this is positive
    :param max_checkpoints: stored as ``self.max_checkpoints``
    :param is_training_action: when true, ``save_model`` is subscribed
        to ``SESS_FINISHED``
    :param _unused: extra keyword arguments, accepted and ignored
    """
    self.save_every_n = save_every_n
    self.max_checkpoints = max_checkpoints
    self.file_name_prefix = make_model_name(model_dir)
    # the saver is not created here; ``init_saver`` is hooked to
    # SESS_STARTED below
    self.saver = None

    SESS_STARTED.connect(self.init_saver)

    # periodic saving is enabled by a positive frequency only
    periodic_saving = self.save_every_n > 0
    if periodic_saving:
        ITER_FINISHED.connect(self.save_model_interval)

    # for training actions, also save once the session has finished
    if is_training_action:
        SESS_FINISHED.connect(self.save_model)
def __init__(self, **_unused):
    """
    Subscribe ``read_console_vars`` to ``ITER_STARTED`` and
    ``print_console_vars`` to ``ITER_FINISHED``.

    :param _unused: extra keyword arguments, accepted and ignored
    """
    subscriptions = (
        (ITER_STARTED, self.read_console_vars),
        (ITER_FINISHED, self.print_console_vars),
    )
    for signal, callback in subscriptions:
        signal.connect(callback)
def __init__(self, **_unused):
    """
    Subscribe ``check_criteria`` to the ``ITER_FINISHED`` event.

    :param _unused: extra keyword arguments, accepted and ignored
    """
    on_iter_finished = self.check_criteria
    ITER_FINISHED.connect(on_iter_finished)
def __init__(self, **_unused):
    """
    Subscribe ``update_performance_history`` to the ``ITER_FINISHED``
    event.

    :param _unused: extra keyword arguments, accepted and ignored
    """
    on_iter_finished = self.update_performance_history
    ITER_FINISHED.connect(on_iter_finished)