def test_done_automatic(self):
    """Episodes that end on their own report done on the expected steps.

    The expected masks below are consistent with environment i terminating
    every (1, 2, 3, 4)[i] steps (presumably the per-env episode lengths
    given to `_create_test_batch_env` -- confirm against the helper).
    No external reset is requested (`reset=False`).
    """
    expected_masks = [
        [True, False, False, False],
        [True, True, False, False],
        [True, False, True, False],
        [True, True, False, True],
    ]
    batch_env = self._create_test_batch_env((1, 2, 3, 4))
    algo = tools.MockAlgorithm(batch_env)
    done, _, _ = tools.simulate(batch_env, algo, log=False, reset=False)
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        # Each evaluation of `done` advances the simulation by one step.
        for expected in expected_masks:
            self.assertAllEqual(expected, sess.run(done))
def test_reset_automatic(self):
    """Finished episodes are automatically restarted by `tools.simulate`.

    After ten simulation steps, every environment's recorded `steps` history
    must match the table below. Full episodes appear first, and the trailing
    entry is the in-progress episode where one exists.
    """
    batch_env = self._create_test_batch_env((1, 2, 3, 4))
    algo = tools.MockAlgorithm(batch_env)
    done, _, _ = tools.simulate(batch_env, algo, log=False, reset=False)
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        num_steps = 10
        for _ in range(num_steps):
            sess.run(done)
    expected_histories = [
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2],
        [3, 3, 3, 1],
        [4, 4, 2],
    ]
    for index, expected in enumerate(expected_histories):
        self.assertAllEqual(expected, batch_env[index].steps)
def test_done_forced(self):
    """Externally forced resets interrupt episodes before they finish.

    `reset` is a boolean placeholder that defaults to False. Feeding True
    for a single step requests a reset of the batch. The schedule pairs each
    step's feed (None meaning "use the default") with the done mask expected
    from that step.
    """
    reset = tf.placeholder_with_default(False, ())
    batch_env = self._create_test_batch_env((2, 4))
    algo = tools.MockAlgorithm(batch_env)
    done, _, _ = tools.simulate(batch_env, algo, log=False, reset=reset)
    force_reset = {reset: True}
    schedule = [
        (None, [False, False]),
        (force_reset, [False, False]),
        (None, [True, False]),
        (force_reset, [False, False]),
        (None, [True, False]),
        (None, [False, False]),
        (None, [True, True]),
    ]
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        for feed, expected in schedule:
            # A `None` feed dict is equivalent to calling sess.run(done).
            self.assertAllEqual(expected, sess.run(done, feed))