Esempio n. 1
0
 def test_phases_feed(self):
     """Each phase's feed dict determines the score it reports.

     Three phases share one score placeholder but feed it 1, 2 and 3 while
     running for 1, 3 and 2 steps respectively. With report_every=1 the
     reported scores repeat that per-phase pattern until max_step is hit.
     """
     score = tf.placeholder(tf.float32, [])
     loop = tools.Loop(None)
     # (phase name, number of steps, value fed into the score placeholder)
     phase_specs = (
         ('phase_1', 1, 1),
         ('phase_2', 3, 2),
         ('phase_3', 2, 3),
     )
     for phase_name, phase_steps, fed_value in phase_specs:
         loop.add_phase(phase_name,
                        done=True,
                        score=score,
                        summary='',
                        steps=phase_steps,
                        report_every=1,
                        log_every=None,
                        checkpoint_every=None,
                        feed={score: fed_value})
     # One full cycle is 1 + 3 + 2 = 6 steps; 15 steps cover two and a half.
     one_cycle = [1] + [2] * 3 + [3] * 2
     expected = (one_cycle * 3)[:15]
     with self.test_session() as sess:
         sess.run(tf.global_variables_initializer())
         scores = list(loop.run(sess, saver=None, max_step=15))
     self.assertAllEqual(expected, scores)
Esempio n. 2
0
def initialize_variables(sess, saver, logdir, checkpoint=None, resume=None):
    """Initialize or restore variables from a checkpoint if available.

    All local and global variables are initialized first. If a checkpoint
    can be located, the saver then restores it on top of the fresh values.

    Args:
      sess: Session to initialize variables in.
      saver: Saver to restore variables.
      logdir: Directory to search for checkpoints.
      checkpoint: Specify what checkpoint name to use; defaults to most recent.
        Interpreted relative to ``logdir`` when one is given, otherwise used
        directly as the checkpoint path.
      resume: Whether to expect recovering a checkpoint or starting a new run.

    Raises:
      ValueError: If resume expected but neither a log directory nor a
        checkpoint was specified.
      RuntimeError: If no resume expected but a checkpoint was found.
    """
    # Initialize everything first; a restore below then overwrites the values
    # of the variables the saver covers.
    sess.run(
        tf.group(tf.local_variables_initializer(),
                 tf.global_variables_initializer()))
    if resume and not (logdir or checkpoint):
        raise ValueError('Need to specify logdir to resume a checkpoint.')
    if logdir:
        if checkpoint:
            checkpoint = os.path.join(logdir, checkpoint)
        else:
            # Fall back to the most recent checkpoint recorded in logdir.
            state = tf.train.get_checkpoint_state(logdir)
            if state and state.model_checkpoint_path:
                checkpoint = state.model_checkpoint_path
    # Without a logdir an explicit checkpoint is used as-is; previously it
    # was silently ignored even though the guard above accepts it.
    if not checkpoint:
        return
    if resume is False:
        message = 'Found unexpected checkpoint when starting a new run.'
        raise RuntimeError(message)
    saver.restore(sess, checkpoint)
Esempio n. 3
0
 def test_done_automatic(self):
     """Done flags follow each environment's episode length without resets.

     The expected flags below are consistent with the tuple (1, 2, 3, 4)
     giving the episode length of each environment in the batch
     (presumably; see _create_test_batch_env).
     """
     env_batch = self._create_test_batch_env((1, 2, 3, 4))
     algorithm = tools.MockAlgorithm(env_batch)
     done, _, _ = tools.simulate(env_batch, algorithm, log=False, reset=False)
     expected_per_step = [
         [True, False, False, False],
         [True, True, False, False],
         [True, False, True, False],
         [True, True, False, True],
     ]
     with self.test_session() as sess:
         sess.run(tf.global_variables_initializer())
         for expected_done in expected_per_step:
             self.assertAllEqual(expected_done, sess.run(done))
Esempio n. 4
0
 def test_reset_automatic(self):
     """Finished episodes are restarted automatically during simulation.

     After 10 simulation steps, each environment's recorded `steps` list is
     checked against the expected per-episode step counts; the last entry
     of a list is the still-unfinished episode.
     """
     env_batch = self._create_test_batch_env((1, 2, 3, 4))
     algorithm = tools.MockAlgorithm(env_batch)
     done, _, _ = tools.simulate(env_batch, algorithm, log=False, reset=False)
     with self.test_session() as sess:
         sess.run(tf.global_variables_initializer())
         for _ in range(10):
             sess.run(done)
     expected_episode_steps = [
         [1] * 10,
         [2] * 5,
         [3, 3, 3, 1],
         [4, 4, 2],
     ]
     for env_index, expected in enumerate(expected_episode_steps):
         self.assertAllEqual(expected, env_batch[env_index].steps)
Esempio n. 5
0
 def test_done_forced(self):
     """Feeding reset=True restarts episodes and clears the done flags."""
     reset = tf.placeholder_with_default(False, ())
     env_batch = self._create_test_batch_env((2, 4))
     algorithm = tools.MockAlgorithm(env_batch)
     done, _, _ = tools.simulate(env_batch, algorithm, False, reset)
     force_reset = {reset: True}
     # Each entry: (feed dict for this run, expected done flags).
     schedule = [
         (None, [False, False]),
         (force_reset, [False, False]),
         (None, [True, False]),
         (force_reset, [False, False]),
         (None, [True, False]),
         (None, [False, False]),
         (None, [True, True]),
     ]
     with self.test_session() as sess:
         sess.run(tf.global_variables_initializer())
         for feed, expected_done in schedule:
             self.assertAllEqual(expected_done, sess.run(done, feed))
Esempio n. 6
0
 def test_not_done(self):
     """Reports average only the scores of steps where the episode is done."""
     global_step = tf.Variable(0, trainable=False, dtype=tf.int32, name='step')
     is_done = tf.equal((global_step + 1) % 2, 0)
     step_score = tf.cast(global_step, tf.float32)
     loop = tools.Loop(None, global_step)
     loop.add_phase('phase_1',
                    done=is_done,
                    score=step_score,
                    summary='',
                    steps=1,
                    report_every=3)
     # Score:  0 1 2 3 4 5 6 7 8
     # Done:     x   x   x   x
     # Report:     x     x     x
     # Reports therefore average {1}, {3, 5} and {7}.
     with self.test_session() as sess:
         sess.run(tf.global_variables_initializer())
         reported = list(loop.run(sess, saver=None, max_step=9))
     self.assertAllEqual([1, 4, 7], reported)
Esempio n. 7
0
 def test_report_every_step(self):
     """The loop yields a report exactly every `report_every` steps."""
     global_step = tf.Variable(0, trainable=False, dtype=tf.int32, name='step')
     loop = tools.Loop(None, global_step)
     loop.add_phase('phase_1',
                    done=True,
                    score=0,
                    summary='',
                    steps=1,
                    report_every=3)
     # Step:   0 1 2 3 4 5 6 7 8
     # Report:     x     x     x
     with self.test_session() as sess:
         sess.run(tf.global_variables_initializer())
         report_stream = loop.run(sess, saver=None, max_step=9)
         for step_after_report in (3, 6, 9):
             next(report_stream)
             self.assertEqual(step_after_report, sess.run(global_step))
Esempio n. 8
0
 def test_average_score_over_phases(self):
     """Each phase reports on its own schedule with its own score."""
     loop = tools.Loop(None)
     # (phase name, constant score, steps per phase visit, report interval)
     phase_specs = (
         ('phase_1', 1, 1, 2),
         ('phase_2', 2, 2, 5),
     )
     for phase_name, phase_score, phase_steps, interval in phase_specs:
         loop.add_phase(phase_name,
                        done=True,
                        score=phase_score,
                        summary='',
                        steps=phase_steps,
                        report_every=interval)
     # Score:    1 2 2 1 2 2 1 2 2 1 2 2 1 2 2 1 2
     # Report 1:       x           x           x
     # Report 2:               x             x
     with self.test_session() as sess:
         sess.run(tf.global_variables_initializer())
         reported = list(loop.run(sess, saver=None, max_step=17))
     self.assertAllEqual([1, 2, 1, 2, 1], reported)
Esempio n. 9
0
 def test_not_done_batch(self):
     """Only finished episodes of a batched phase contribute to the score.

     The phase has a batch of two entries: the first is done whenever the
     step is a multiple of 3 and scores the step; the second is done whenever
     the step is a multiple of 4 and scores the squared step.
     """
     step = tf.Variable(0, False, dtype=tf.int32, name='step')
     done = tf.equal([step % 3, step % 4], 0)
     score = tf.cast([step, step**2], tf.float32)
     loop = tools.Loop(None, step)
     loop.add_phase('phase_1',
                    done,
                    score,
                    summary='',
                    steps=1,
                    report_every=8)
     # Per this table the step counter advances by 2 per iteration
     # (presumably the batch size), so the observed steps are 0, 2, 4, 6.
     # Step:    0  2  4  6
     # Score 1: 0  2  4  6
     # Done 1:  x        x
     # Score 2: 0  4 16 36
     # Done 2:  x     x
     with self.test_session() as sess:
         sess.run(tf.global_variables_initializer())
         scores = list(loop.run(sess, saver=None, max_step=8))
         self.assertEqual(8, sess.run(step))
     # Mean of the four done scores; float divisor avoids Python 2 int division.
     self.assertAllEqual([(0 + 0 + 16 + 6) / 4.0], scores)
Esempio n. 10
0
 def _initialize_vars(self):
     """Run the global variable initializer in this object's session."""
     init_op = tf.global_variables_initializer()
     self.sess.run(init_op)