def _training_loop(self, sess, loop_status):
    """
    Run the training iterations, with an optional nested validation loop.

    A single ``IterationMessage`` object is created before the loop and
    reused at every iteration to send network output to/receive
    controlling messages from self.app. The iteration message is passed
    into `self.run_vars`, where graph elements to run are collected and
    fed into `session.run()`.

    A nested validation loop of ``self.validation_max_iter`` steps runs
    every ``self.validation_every_n`` iterations (never at iteration 0)
    when ``self.validation_every_n > 0``. During the validation loop the
    network parameters remain unchanged.

    The model is saved via ``self._save_model`` whenever
    ``self.save_every_n > 0`` and the iteration index is a multiple of
    it (this includes iteration 0 if the loop starts there).

    The loop stops early when ``self._coord`` requests a stop or when
    ``iter_msg.should_stop`` is truthy (presumably set by self.app).
    The TensorBoard summary writers opened here are closed on every exit
    path (normal end, early stop, or exception).

    :param sess: the TensorFlow session whose graph is written to the
        summaries and which is passed to ``self.run_vars``.
    :param loop_status: mapping updated in place with the index of the
        iteration being run under the key ``'current_iter'``.
    """
    iter_msg = IterationMessage()

    # initialise tf summary writers; the validation writer is only
    # needed when validation is enabled.
    writer_train = tf.summary.FileWriter(
        os.path.join(self.summary_dir, TRAIN), sess.graph)
    writer_valid = None
    try:
        if self.validation_every_n > 0:
            writer_valid = tf.summary.FileWriter(
                os.path.join(self.summary_dir, VALID), sess.graph)

        for iter_i in range(self.initial_iter, self.final_iter):
            # general loop information
            loop_status['current_iter'] = iter_i
            if self._coord.should_stop() or iter_msg.should_stop:
                break

            # update variables/operations to run, from self.app
            iter_msg.current_iter, iter_msg.phase = iter_i, TRAIN
            self.run_vars(sess, iter_msg)

            self.app.interpret_output(
                iter_msg.current_iter_output[NETWORK_OUTPUT])
            iter_msg.to_tf_summary(writer_train)
            tf.logging.info(iter_msg.to_console_string())

            # run validations if required (skipped at iteration 0)
            if iter_i > 0 and self.validation_every_n > 0 and \
                    (iter_i % self.validation_every_n == 0):
                for _ in range(self.validation_max_iter):
                    iter_msg.current_iter, iter_msg.phase = iter_i, VALID
                    self.run_vars(sess, iter_msg)
                    # save iteration results
                    if writer_valid is not None:
                        iter_msg.to_tf_summary(writer_valid)
                    tf.logging.info(iter_msg.to_console_string())

            if self.save_every_n > 0 and (iter_i % self.save_every_n == 0):
                self._save_model(sess, iter_i)
    finally:
        # Per the TF docs, FileWriter updates its event file
        # asynchronously; close() flushes pending events to disk and
        # releases the file. Without this, the writers were never closed.
        if writer_valid is not None:
            writer_valid.close()
        writer_train.close()
def test_interfaces(self):
    """Check the initial state and accessors of a fresh IterationMessage.

    After ``current_iter`` is set to 0, every public property is read in
    a fixed order and compared against its expected default. Then
    ``current_iter_output`` is assigned, and the stored output, the
    iteration duration, the console string (expected to start with
    ``'training'``) and the return value of ``to_tf_summary`` are checked.

    NOTE(review): another ``test_interfaces`` with different expectations
    (``should_stop`` False, ``'Training'`` prefix) sits next to this one
    in the source. If both live in the same class, only the later
    definition is collected; confirm which matches ``IterationMessage``.
    """
    message = IterationMessage()
    message.current_iter = 0

    # (property name, expected value) pairs, read in this exact order.
    # Comparing whole lists makes a failure report the offending name.
    expected_state = [
        ('current_iter', 0),
        ('ops_to_run', {}),
        ('data_feed_dict', {}),
        ('current_iter_output', None),
        ('should_stop', None),
        ('phase', TRAIN),
        ('is_training', True),
        ('is_validation', False),
        ('is_inference', False),
    ]
    observed_state = [
        (name, getattr(message, name)) for name, _ in expected_state]
    self.assertEqual(observed_state, expected_state)

    message.current_iter_output = {'test'}
    self.assertEqual(message.current_iter_output, {'test'})
    self.assertGreater(message.iter_duration, 0.0)
    self.assertStartsWith(message.to_console_string(), 'training')
    self.assertEqual(message.to_tf_summary(0), None)
def test_interfaces(self):
    """Check the initial state and accessors of a fresh IterationMessage.

    After ``current_iter`` is set to 0, every public property is read in
    a fixed order and compared against its expected default. Then
    ``current_iter_output`` is assigned, and the stored output, the
    iteration duration, the console string (expected to start with
    ``'Training'``) and the return value of ``to_tf_summary`` are checked.

    NOTE(review): another ``test_interfaces`` with different expectations
    (``should_stop`` None, ``'training'`` prefix) sits next to this one
    in the source. If both live in the same class, only the later
    definition is collected; confirm which matches ``IterationMessage``.
    """
    message = IterationMessage()
    message.current_iter = 0

    # (property name, expected value) pairs, read in this exact order.
    # Comparing whole lists makes a failure report the offending name.
    expected_state = [
        ('current_iter', 0),
        ('ops_to_run', {}),
        ('data_feed_dict', {}),
        ('current_iter_output', None),
        ('should_stop', False),
        ('phase', TRAIN),
        ('is_training', True),
        ('is_validation', False),
        ('is_inference', False),
    ]
    observed_state = [
        (name, getattr(message, name)) for name, _ in expected_state]
    self.assertEqual(observed_state, expected_state)

    message.current_iter_output = {'test'}
    self.assertEqual(message.current_iter_output, {'test'})
    self.assertGreater(message.iter_duration, 0.0)
    self.assertStartsWith(message.to_console_string(), 'Training')
    self.assertEqual(message.to_tf_summary(0), None)