Пример #1
0
    def _training_loop(self, sess, loop_status):
        """
        Run training iterations from ``self.initial_iter`` up to (but not
        including) ``self.final_iter``.

        A single ``IterationMessage`` object is reused across iterations
        to send network output to/receive controlling messages from self.app.
        The iteration message is passed into `self.run_vars`,
        where graph elements to run are collected and fed into `session.run()`.
        A nested validation loop of ``self.validation_max_iter`` steps runs
        every ``self.validation_every_n`` iterations (never at iteration 0)
        if self.validation_every_n > 0.  During the validation loop
        the network parameters remain unchanged.

        The loop ends early when ``self._coord.should_stop()`` is true or
        when the iteration message has its ``should_stop`` flag set.
        The summary writers are always closed on exit, including when an
        exception propagates out of the loop.

        :param sess: session object; its ``graph`` is handed to the summary
            writers and the session itself is forwarded to ``self.run_vars``
            and ``self._save_model``.
        :param loop_status: mapping whose ``'current_iter'`` entry is
            overwritten with the iteration index at the start of
            every iteration.
        """

        iter_msg = IterationMessage()

        # initialise tf summary writers; the validation writer is only
        # created when validation is enabled
        writer_train = tf.summary.FileWriter(
            os.path.join(self.summary_dir, TRAIN), sess.graph)
        writer_valid = None
        try:
            if self.validation_every_n > 0:
                writer_valid = tf.summary.FileWriter(
                    os.path.join(self.summary_dir, VALID), sess.graph)

            for iter_i in range(self.initial_iter, self.final_iter):
                # general loop information
                loop_status['current_iter'] = iter_i
                if self._coord.should_stop() or iter_msg.should_stop:
                    break

                # update variables/operations to run, from self.app
                iter_msg.current_iter, iter_msg.phase = iter_i, TRAIN
                self.run_vars(sess, iter_msg)

                self.app.interpret_output(
                    iter_msg.current_iter_output[NETWORK_OUTPUT])
                iter_msg.to_tf_summary(writer_train)
                tf.logging.info(iter_msg.to_console_string())

                # run validations if required
                if iter_i > 0 and self.validation_every_n > 0 and \
                        (iter_i % self.validation_every_n == 0):
                    for _ in range(self.validation_max_iter):
                        iter_msg.current_iter, iter_msg.phase = iter_i, VALID
                        self.run_vars(sess, iter_msg)
                        # save iteration results
                        if writer_valid is not None:
                            iter_msg.to_tf_summary(writer_valid)
                        tf.logging.info(iter_msg.to_console_string())

                # NOTE: unlike validation, there is no ``iter_i > 0`` guard
                # here, so a model is also saved at iteration 0
                if self.save_every_n > 0 and \
                        (iter_i % self.save_every_n == 0):
                    self._save_model(sess, iter_i)
        finally:
            # flush buffered summaries and release the event files, even
            # when the loop is interrupted by an exception
            for writer in (writer_train, writer_valid):
                if writer is not None:
                    writer.close()
Пример #2
0
 def test_interfaces(self):
     """Exercise the public attributes of a fresh ``IterationMessage``.

     The message is created, its iteration counter is set to 0, and the
     remaining attributes are compared against the values expected for
     an untouched training-phase message.  Afterwards the output slot,
     the duration and the two reporting helpers are checked.
     """
     message = IterationMessage()
     message.current_iter = 0

     # (attribute name, expected value) pairs, verified in this order
     expected_state = (
         ('current_iter', 0),
         ('ops_to_run', {}),
         ('data_feed_dict', {}),
         ('current_iter_output', None),
         ('should_stop', None),
         ('phase', TRAIN),
         ('is_training', True),
         ('is_validation', False),
         ('is_inference', False),
     )
     for attr_name, expected_value in expected_state:
         self.assertEqual(getattr(message, attr_name), expected_value)

     # the output slot hands back a value equal to the one stored
     message.current_iter_output = {'test'}
     self.assertEqual(message.current_iter_output, {'test'})

     # a positive duration is expected once an output has been set
     self.assertGreater(message.iter_duration, 0.0)

     # console report starts with the phase label; the summary call
     # returns nothing
     console_text = message.to_console_string()
     self.assertStartsWith(console_text, 'training')
     summary_result = message.to_tf_summary(0)
     self.assertEqual(summary_result, None)
Пример #3
0
 def test_interfaces(self):
     """Verify defaults and accessors of a new ``IterationMessage``.

     Covers, in order: the iteration counter round-trip, the initial
     values of the run/feed containers and flags, the phase predicates,
     storing an output, and the console/summary reporting helpers.
     """
     check_equal = self.assertEqual
     iteration_message = IterationMessage()

     # iteration counter can be assigned and read back
     iteration_message.current_iter = 0
     check_equal(iteration_message.current_iter, 0)

     # nothing has been scheduled, fed or produced yet
     check_equal(iteration_message.ops_to_run, {})
     check_equal(iteration_message.data_feed_dict, {})
     check_equal(iteration_message.current_iter_output, None)
     check_equal(iteration_message.should_stop, False)

     # a new message reports the training phase
     check_equal(iteration_message.phase, TRAIN)
     check_equal(iteration_message.is_training, True)
     check_equal(iteration_message.is_validation, False)
     check_equal(iteration_message.is_inference, False)

     # storing an output makes it readable again
     iteration_message.current_iter_output = {'test'}
     check_equal(iteration_message.current_iter_output, {'test'})
     self.assertGreater(iteration_message.iter_duration, 0.0)

     # reporting helpers: console text prefix and summary return value
     self.assertStartsWith(
         iteration_message.to_console_string(), 'Training')
     check_equal(iteration_message.to_tf_summary(0), None)