Ejemplo n.º 1
0
    def _training_loop(self, sess, loop_status):
        """
        Run training iterations, with optional interleaved validation.

        A single ``IterationMessage`` object is reused across iterations
        to send network output to/receive controlling messages from self.app.
        At each iteration the message is passed into `self.run_vars`,
        where graph elements to run are collected and fed into
        `session.run()`.
        A nested validation loop of ``self.validation_max_iter`` steps runs
        every ``self.validation_every_n`` iterations (never at iteration 0)
        if ``self.validation_every_n > 0``. During the validation loop
        the network parameters remain unchanged.
        ``self._save_model`` is called every ``self.save_every_n``
        iterations if ``self.save_every_n > 0``.

        The loop stops early when ``self._coord.should_stop()`` or
        ``iter_msg.should_stop`` is truthy. The summary writers are always
        closed on exit, so buffered summary events are flushed even when
        the loop stops early or raises.

        :param sess: session used to run the graph
        :param loop_status: dict updated in place; ``'current_iter'`` holds
            the index of the iteration most recently started
        """

        iter_msg = IterationMessage()

        # initialise tf summary writers; both are closed in ``finally`` so
        # pending events are written to disk regardless of how we exit.
        writer_train = tf.summary.FileWriter(
            os.path.join(self.summary_dir, TRAIN), sess.graph)
        writer_valid = None
        try:
            if self.validation_every_n > 0:
                writer_valid = tf.summary.FileWriter(
                    os.path.join(self.summary_dir, VALID), sess.graph)

            for iter_i in range(self.initial_iter, self.final_iter):
                # general loop information
                loop_status['current_iter'] = iter_i
                if self._coord.should_stop():
                    break
                if iter_msg.should_stop:
                    break

                # update variables/operations to run, from self.app
                iter_msg.current_iter, iter_msg.phase = iter_i, TRAIN
                self.run_vars(sess, iter_msg)

                self.app.interpret_output(
                    iter_msg.current_iter_output[NETWORK_OUTPUT])
                iter_msg.to_tf_summary(writer_train)
                tf.logging.info(iter_msg.to_console_string())

                # run validations if required
                if iter_i > 0 and self.validation_every_n > 0 and \
                        (iter_i % self.validation_every_n == 0):
                    for _ in range(self.validation_max_iter):
                        iter_msg.current_iter, iter_msg.phase = iter_i, VALID
                        self.run_vars(sess, iter_msg)
                        # save iteration results
                        if writer_valid is not None:
                            iter_msg.to_tf_summary(writer_valid)
                        tf.logging.info(iter_msg.to_console_string())

                if self.save_every_n > 0 and \
                        (iter_i % self.save_every_n == 0):
                    self._save_model(sess, iter_i)
        finally:
            writer_train.close()
            if writer_valid is not None:
                writer_valid.close()
Ejemplo n.º 2
0
    def _inference_loop(self, sess, loop_status):
        """
        Repeatedly run the inference graph until told to stop.

        Each pass runs, via ``self.run_vars``, the elements collected in an
        ``IterationMessage`` whose phase is ``INFER``, then hands the
        network output to ``self.app.interpret_output``. The loop ends when
        that call returns a falsy value, or when either the coordinator or
        the iteration message requests a stop.

        :param sess: session forwarded to ``self.run_vars``
        :param loop_status: dict updated in place; ``'all_saved_flag'`` is
            set to True only when the loop ended because
            ``interpret_output`` returned a falsy value
        """
        iter_msg = IterationMessage()
        loop_status['all_saved_flag'] = False
        counter = 0
        while not (self._coord.should_stop() or iter_msg.should_stop):
            iter_msg.current_iter = counter
            iter_msg.phase = INFER
            # run_vars executes the requested elements and stores their
            # values in iter_msg.current_iter_output
            self.run_vars(sess, iter_msg)
            counter += 1

            network_out = iter_msg.current_iter_output[NETWORK_OUTPUT]
            if self.app.interpret_output(network_out):
                tf.logging.info(iter_msg.to_console_string())
                continue

            # a falsy return from the application means nothing is left
            tf.logging.info('processed all batches.')
            loop_status['all_saved_flag'] = True
            break
Ejemplo n.º 3
0
    def test_run_vars(self):
        """
        Drive ``app_driver.run_vars`` through train/validation iterations.

        Checks that the monitored weight tensor changes after training
        iterations but not after a validation iteration, and that the
        console string reports the current phase.
        """
        app_driver = get_initialised_driver()
        test_graph = app_driver._create_graph(app_driver.graph)
        test_tensor = app_driver.graph.get_tensor_by_name(
            "G/conv_bn_selu/conv_/w:0")

        with self.test_session(graph=test_graph) as sess:
            # NOTE(review): called once without and once with a session;
            # confirm the session-less call is intentional.
            app_driver._run_sampler_threads()
            app_driver._run_sampler_threads(sess)
            sess.run(app_driver._init_op)

            iter_msg = IterationMessage()

            # run 1st training iter
            iter_msg.current_iter, iter_msg.phase = 1, TRAIN
            app_driver.run_vars(sess, iter_msg)
            model_value_1 = sess.run(test_tensor)
            self.assertGreater(iter_msg.iter_duration, 0.0)
            print(iter_msg.to_console_string())
            self.assertRegexpMatches(iter_msg.to_console_string(), 'Training')

            # run 2nd training iter
            iter_msg.current_iter, iter_msg.phase = 2, TRAIN
            app_driver.run_vars(sess, iter_msg)
            model_value_2 = sess.run(test_tensor)
            # make sure model gets updated
            self.assertNotAlmostEqual(
                np.mean(np.abs(model_value_1 - model_value_2)), 0.0)
            print(iter_msg.to_console_string())
            self.assertRegexpMatches(iter_msg.to_console_string(), 'Training')

            # run validation iter; set ``current_iter`` (not a new
            # attribute) so the message is reset like the other steps
            iter_msg.current_iter, iter_msg.phase = 3, VALID
            app_driver.run_vars(sess, iter_msg)
            model_value_3 = sess.run(test_tensor)
            # make sure model is not updated
            self.assertAlmostEqual(
                np.mean(np.abs(model_value_2 - model_value_3)), 0.0)
            print(iter_msg.to_console_string())
            self.assertRegexpMatches(iter_msg.to_console_string(),
                                     'Validation')

            # run training iter
            iter_msg.current_iter, iter_msg.phase = 4, TRAIN
            app_driver.run_vars(sess, iter_msg)
            model_value_4 = sess.run(test_tensor)
            # make sure model gets updated
            self.assertNotAlmostEqual(
                np.mean(np.abs(model_value_2 - model_value_4)), 0.0)
            self.assertNotAlmostEqual(
                np.mean(np.abs(model_value_3 - model_value_4)), 0.0)
            print(iter_msg.to_console_string())
            self.assertRegexpMatches(iter_msg.to_console_string(), 'Training')

            app_driver.app.stop()
            self.assertEqual(iter_msg.ops_to_run, {})
Ejemplo n.º 4
0
    def test_run_vars(self):
        """
        Drive ``app_driver.run_vars`` through train/validation iterations.

        Checks that the monitored weight tensor changes after training
        iterations but not after a validation iteration, and that the
        console string reports the current phase.
        """
        app_driver = get_initialised_driver()
        test_graph = app_driver._create_graph(app_driver.graph)
        test_tensor = app_driver.graph.get_tensor_by_name(
            "G/conv_bn_selu/conv_/w:0")

        with self.test_session(graph=test_graph) as sess:
            # NOTE(review): called once without and once with a session;
            # confirm the session-less call is intentional.
            app_driver._run_sampler_threads()
            app_driver._run_sampler_threads(sess)
            sess.run(app_driver._init_op)

            iter_msg = IterationMessage()

            # run 1st training iter
            iter_msg.current_iter, iter_msg.phase = 1, TRAIN
            app_driver.run_vars(sess, iter_msg)
            model_value_1 = sess.run(test_tensor)
            self.assertGreater(iter_msg.iter_duration, 0.0)
            print(iter_msg.to_console_string())
            self.assertRegexpMatches(iter_msg.to_console_string(), 'Training')

            # run 2nd training iter
            iter_msg.current_iter, iter_msg.phase = 2, TRAIN
            app_driver.run_vars(sess, iter_msg)
            model_value_2 = sess.run(test_tensor)
            # make sure model gets updated
            self.assertNotAlmostEqual(
                np.mean(np.abs(model_value_1 - model_value_2)), 0.0)
            print(iter_msg.to_console_string())
            self.assertRegexpMatches(iter_msg.to_console_string(), 'Training')

            # run validation iter; set ``current_iter`` (not a new
            # attribute) so the message is reset like the other steps
            iter_msg.current_iter, iter_msg.phase = 3, VALID
            app_driver.run_vars(sess, iter_msg)
            model_value_3 = sess.run(test_tensor)
            # make sure model is not updated
            self.assertAlmostEqual(
                np.mean(np.abs(model_value_2 - model_value_3)), 0.0)
            print(iter_msg.to_console_string())
            self.assertRegexpMatches(iter_msg.to_console_string(), 'Validation')

            # run training iter
            iter_msg.current_iter, iter_msg.phase = 4, TRAIN
            app_driver.run_vars(sess, iter_msg)
            model_value_4 = sess.run(test_tensor)
            # make sure model gets updated
            self.assertNotAlmostEqual(
                np.mean(np.abs(model_value_2 - model_value_4)), 0.0)
            self.assertNotAlmostEqual(
                np.mean(np.abs(model_value_3 - model_value_4)), 0.0)
            print(iter_msg.to_console_string())
            self.assertRegexpMatches(iter_msg.to_console_string(), 'Training')

            app_driver.app.stop()
            self.assertEqual(iter_msg.ops_to_run, {})
Ejemplo n.º 5
0
    def test_set_fields(self):
        """
        Assigning ``current_iter`` / ``current_iter_output`` updates the
        derived fields, and a non-integer iteration raises ValueError.
        """
        message = IterationMessage()

        # assigning an iteration number sets the tic and clears the output
        message.current_iter = 3
        self.assertGreater(message._current_iter_tic, 0.0)
        self.assertEqual(message._current_iter_output, None)

        # assigning an output records the iteration duration
        message.current_iter_output = {CONSOLE: {'test': 'test'}}
        self.assertEqual(message.current_iter, 3)
        self.assertGreater(message.iter_duration, 0.0)
        console_line = message.to_console_string()
        self.assertRegexpMatches(console_line, '.*test=test.*')

        # a string is not accepted as an iteration number
        with self.assertRaisesRegexp(ValueError, ''):
            message.current_iter = 'test'
Ejemplo n.º 6
0
    def test_set_fields(self):
        """
        Check the side effects of the IterationMessage setters and that an
        invalid iteration value is rejected with ValueError.
        """
        iteration_msg = IterationMessage()

        # the iteration setter initialises the tic and resets the output
        iteration_msg.current_iter = 3
        self.assertGreater(iteration_msg._current_iter_tic, 0.0)
        self.assertEqual(iteration_msg._current_iter_output, None)

        # the output setter leaves the iteration intact and times it
        console_output = {CONSOLE: {'test': 'test'}}
        iteration_msg.current_iter_output = console_output
        self.assertEqual(iteration_msg.current_iter, 3)
        self.assertGreater(iteration_msg.iter_duration, 0.0)
        self.assertRegexpMatches(
            iteration_msg.to_console_string(), '.*test=test.*')

        # non-integer iteration values must raise
        with self.assertRaisesRegexp(ValueError, ''):
            iteration_msg.current_iter = 'test'
Ejemplo n.º 7
0
 def test_interfaces(self):
     """
     Check the default state and basic accessors of IterationMessage.

     Expectations match the other tests in this file: ``should_stop``
     defaults to False and the console string of a training message
     begins with 'Training'.
     """
     msg = IterationMessage()
     msg.current_iter = 0
     self.assertEqual(msg.current_iter, 0)
     self.assertEqual(msg.ops_to_run, {})
     self.assertEqual(msg.data_feed_dict, {})
     self.assertEqual(msg.current_iter_output, None)
     self.assertEqual(msg.should_stop, False)
     self.assertEqual(msg.phase, TRAIN)
     self.assertEqual(msg.is_training, True)
     self.assertEqual(msg.is_validation, False)
     self.assertEqual(msg.is_inference, False)
     msg.current_iter_output = {'test'}
     self.assertEqual(msg.current_iter_output, {'test'})
     self.assertGreater(msg.iter_duration, 0.0)
     self.assertStartsWith(msg.to_console_string(), 'Training')
     # to_tf_summary is called with a dummy writer value (0) and is
     # expected to return None
     self.assertEqual(msg.to_tf_summary(0), None)
Ejemplo n.º 8
0
 def test_interfaces(self):
     """Check the default state and basic accessors of IterationMessage."""
     message = IterationMessage()
     message.current_iter = 0

     # expected attribute values right after setting the iteration
     expected_values = (
         ('current_iter', 0),
         ('ops_to_run', {}),
         ('data_feed_dict', {}),
         ('current_iter_output', None),
         ('should_stop', False),
         ('phase', TRAIN),
         ('is_training', True),
         ('is_validation', False),
         ('is_inference', False),
     )
     for attr_name, expected in expected_values:
         self.assertEqual(getattr(message, attr_name), expected)

     # storing an output makes it readable back and records a duration
     message.current_iter_output = {'test'}
     self.assertEqual(message.current_iter_output, {'test'})
     self.assertGreater(message.iter_duration, 0.0)
     self.assertStartsWith(message.to_console_string(), 'Training')
     # to_tf_summary is called with a dummy writer value (0)
     self.assertEqual(message.to_tf_summary(0), None)