def test_multi_env(self):
    """
    Test monitoring when two LoggedEnv wrappers write to the
    same log file.
    """
    tmp_dir = tempfile.mkdtemp()
    try:
        csv_path = os.path.join(tmp_dir, 'monitor.csv')
        # Both wrappers target the same file, with use_locking=True.
        logged_envs = [
            LoggedEnv(SimpleEnv(seed, (3,), 'float32'), csv_path, use_locking=True)
            for seed in (2, 3)
        ]
        for logged in logged_envs:
            logged.reset()
        # Interleave the two environments, restarting each one
        # whenever its episode finishes.
        for _ in range(13):
            for logged in logged_envs:
                finished = logged.step(logged.action_space.sample())[2]
                if finished:
                    logged.reset()
        for logged in logged_envs:
            logged.close()
        # The handle itself is unused; pandas reads the file by path.
        with open(csv_path, 'rt'):
            table = pandas.read_csv(csv_path)
        self.assertEqual(list(table['r']), [2, 2.5, 2, 2.5, 2, 2, 2.5])
        self.assertEqual(list(table['l']), [3, 4, 3, 4, 3, 3, 4])
    finally:
        shutil.rmtree(tmp_dir)
def test_multi_env():
    """
    Test monitoring when two LoggedEnv wrappers write to the
    same log file.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        csv_path = os.path.join(tmp_dir, 'monitor.csv')
        # Both wrappers target the same file, with use_locking=True.
        logged_envs = [
            LoggedEnv(SimpleEnv(seed, (3, ), 'float32'), csv_path, use_locking=True)
            for seed in (2, 3)
        ]
        for logged in logged_envs:
            logged.reset()
        # Interleave the two environments, restarting each one
        # whenever its episode finishes.
        for _ in range(13):
            for logged in logged_envs:
                finished = logged.step(logged.action_space.sample())[2]
                if finished:
                    logged.reset()
        for logged in logged_envs:
            logged.close()
        # The handle itself is unused; pandas reads the file by path.
        with open(csv_path, 'rt'):
            table = pandas.read_csv(csv_path)
        assert list(table['r']) == [2, 2.5, 2, 2.5, 2, 2, 2.5]
        assert list(table['l']) == [3, 4, 3, 4, 3, 3, 4]
def test_skip():
    """
    Test a FrameSkipEnv wrapper.

    The raw environment is first stepped manually (three steps
    with one action, two with another) to record the expected
    observations and summed rewards; the same environment is
    then wrapped with num_frames=3 and must reproduce them in
    two wrapped steps.
    """
    # Timestep limit is 5.
    raw_env = SimpleEnv(4, (3, 2, 5), 'float32')
    first_action = np.random.uniform(high=255.0, size=(3, 2, 5))
    second_action = np.random.uniform(high=255.0, size=(3, 2, 5))

    # Reference pass on the unwrapped environment.
    start_obs = raw_env.reset()
    first_total = 0.0
    second_total = 0.0
    for _ in range(3):
        mid_obs, reward, _, _ = raw_env.step(first_action)
        first_total += reward
    for _ in range(2):
        final_obs, reward, done, _ = raw_env.step(second_action)
        second_total += reward
    assert done

    # Same trajectory through the frame-skipping wrapper.
    skip_env = FrameSkipEnv(raw_env, num_frames=3)
    got_start = skip_env.reset()
    assert np.allclose(got_start, start_obs)

    got_mid, got_first_total, done, _ = skip_env.step(first_action)
    assert not done
    assert got_first_total == first_total
    assert np.allclose(got_mid, mid_obs)

    got_final, got_second_total, done, _ = skip_env.step(second_action)
    assert done
    assert got_second_total == second_total
    assert np.allclose(got_final, final_obs)
def test_truncation(stateful, state_tuple):
    """
    Test sequence truncation for TruncatedRoller with a batch
    of one environment.

    Five episodes from a BasicRoller serve as the reference.
    The TruncatedRoller is given a step budget of just over
    half the reference step count, so its first batch must end
    on a truncated rollout and its second batch must continue
    from there.
    """
    def make_env():
        return SimpleEnv(7, (5, 3), 'uint8')

    single_env = make_env()
    model = SimpleModel(single_env.action_space.low.shape,
                        stateful=stateful,
                        state_tuple=state_tuple)
    reference = BasicRoller(single_env, model, min_episodes=5).rollouts()
    step_count = sum(r.num_steps for r in reference)

    roller = TruncatedRoller(batched_gym_env([make_env], sync=True),
                             model,
                             step_count // 2 + 1)
    first_batch = roller.rollouts()
    assert first_batch[-1].trunc_end
    second_batch = roller.rollouts()

    want_first, want_second = _artificial_truncation(reference,
                                                     len(first_batch) - 1,
                                                     first_batch[-1].num_steps)
    # The second batch holds one rollout more than the reference
    # remainder; that trailing rollout is excluded from comparison.
    assert len(second_batch) == len(want_second) + 1
    second_batch = second_batch[:-1]
    _compare_rollout_batch(first_batch, want_first)
    _compare_rollout_batch(second_batch, want_second)
def _test_batch_equivalence_case(self, stateful, state_tuple):
    """
    Test that doing things in batches is consistent, given the
    model parameters.

    Rollouts gathered from 15 single-environment rollers are
    compared (ignoring order) with rollouts from one roller
    over all 15 environments split into three sub-batches.
    """
    env_fns = [lambda seed=x: SimpleEnv(seed, (5, 3), 'uint8') for x in range(15)]
    model = SimpleModel((5, 3), stateful=stateful, state_tuple=state_tuple)

    # One roller per environment, three rollout calls each.
    one_at_a_time = []
    for make_env in env_fns:
        solo_roller = TruncatedRoller(batched_gym_env([make_env], sync=True), model, 17)
        for _ in range(3):
            one_at_a_time += solo_roller.rollouts()

    # A single roller over every environment at once.
    joint_roller = TruncatedRoller(
        batched_gym_env(env_fns, num_sub_batches=3, sync=True), model, 17)
    all_together = []
    for _ in range(3):
        all_together += joint_roller.rollouts()

    _compare_rollout_batch(self, one_at_a_time, all_together, ordered=False)
def test_batched_stack(concat):
    """
    Test that BatchedFrameStack is equivalent to a regular
    batched FrameStackEnv.

    Six environments are split into three sub-batches. One
    batched env is stacked as a whole (BatchedFrameStack); the
    other stacks every environment individually before
    batching. Both receive identical resets and actions and
    must produce matching observations, rewards and done flags.
    """
    make_envs = [
        lambda idx=i: SimpleEnv(idx + 2, (3, 2, 5), 'float32')
        for i in range(6)
    ]
    stacked_batch = BatchedFrameStack(
        batched_gym_env(make_envs, num_sub_batches=3, sync=True),
        concat=concat)
    batch_of_stacked = batched_gym_env(
        [lambda make=fn: FrameStackEnv(make(), concat=concat) for fn in make_envs],
        num_sub_batches=3,
        sync=True)

    for step_idx in range(50):
        for sub in range(3):
            # Reset on the very first pass, then at staggered
            # intervals per sub-batch.
            needs_reset = step_idx == 0 or (step_idx + sub) % 17 == 0
            if needs_reset:
                stacked_batch.reset_start(sub_batch=sub)
                batch_of_stacked.reset_start(sub_batch=sub)
                obs_a = stacked_batch.reset_wait(sub_batch=sub)
                obs_b = batch_of_stacked.reset_wait(sub_batch=sub)
                assert np.allclose(obs_a, obs_b)

            # Two actions per step call: six envs over three sub-batches.
            actions = [stacked_batch.action_space.sample() for _ in range(2)]
            stacked_batch.step_start(actions, sub_batch=sub)
            batch_of_stacked.step_start(actions, sub_batch=sub)
            obs_a, rews_a, dones_a, _ = stacked_batch.step_wait(sub_batch=sub)
            obs_b, rews_b, dones_b, _ = batch_of_stacked.step_wait(sub_batch=sub)
            assert np.allclose(obs_a, obs_b)
            assert np.array(rews_a == rews_b).all()
            assert np.array(dones_a == dones_b).all()
def _test_truncation_case(self, stateful, state_tuple):
    """
    Test rollout truncation and continuation for a specific
    set of model parameters.

    Five BasicRoller episodes are the reference. The
    TruncatedRoller gets a step budget of just over half the
    reference step count, so its first batch must end on a
    truncated rollout which the second batch then continues.
    """
    def make_env():
        return SimpleEnv(7, (5, 3), 'uint8')

    single_env = make_env()
    model = SimpleModel(single_env.action_space.low.shape,
                        stateful=stateful,
                        state_tuple=state_tuple)
    reference = BasicRoller(single_env, model, min_episodes=5).rollouts()
    step_count = sum(r.num_steps for r in reference)

    roller = TruncatedRoller(batched_gym_env([make_env], sync=True),
                             model,
                             step_count // 2 + 1)
    first_batch = roller.rollouts()
    self.assertTrue(first_batch[-1].trunc_end)
    second_batch = roller.rollouts()

    want_first, want_second = _artificial_truncation(reference,
                                                     len(first_batch) - 1,
                                                     first_batch[-1].num_steps)
    # The second batch holds one rollout more than the reference
    # remainder; that trailing rollout is excluded from comparison.
    self.assertEqual(len(second_batch), len(want_second) + 1)
    second_batch = second_batch[:-1]
    _compare_rollout_batch(self, first_batch, want_first)
    _compare_rollout_batch(self, second_batch, want_second)
def _test_batch_equivalence_case(self, stateful, state_tuple, **roller_kwargs):
    """
    Test BasicRoller equivalence when using a batch of
    environments.

    An EpisodeRoller is run over 21 copies of the same
    environment (seven sub-batches) and its rollouts are
    compared against a BasicRoller asked for the same number
    of episodes.

    Args:
      stateful: forwarded to SimpleModel.
      state_tuple: forwarded to SimpleModel.
      roller_kwargs: limits forwarded to EpisodeRoller.
    """
    env_fn = lambda: SimpleEnv(3, (4, 5), 'uint8')
    model = SimpleModel((4, 5), stateful=stateful, state_tuple=state_tuple)
    batched_env = batched_gym_env([env_fn] * 21, num_sub_batches=7, sync=True)
    ep_roller = EpisodeRoller(batched_env, model, **roller_kwargs)
    actual = ep_roller.rollouts()

    total_steps = sum(r.num_steps for r in actual)
    # Dedicated comparison assertions report both operands on
    # failure, unlike assertTrue on a pre-evaluated boolean.
    self.assertGreaterEqual(len(actual), ep_roller.min_episodes)
    self.assertGreaterEqual(total_steps, ep_roller.min_steps)

    if 'min_steps' not in roller_kwargs:
        # With only an episode limit the exact rollout count is pinned.
        num_eps = ep_roller.min_episodes + batched_env.num_envs - 1
        self.assertEqual(len(actual), num_eps)

    basic_roller = BasicRoller(env_fn(), model, min_episodes=len(actual))
    expected = basic_roller.rollouts()
    _compare_rollout_batch(self, actual, expected)
def test_resize_even(self):
    """
    Test resizing for an even number of pixels.

    The first frame of a (13, 5, 3) environment is resized to
    (5, 4) through ResizeImageEnv and compared against
    tf.image.resize_images with the AREA method.
    """
    env = SimpleEnv(5, (13, 5, 3), 'float32')
    frame = env.reset()
    actual = ResizeImageEnv(env, size=(5, 4)).reset()
    # Run the reference op inside a context manager so the session
    # is closed instead of being leaked.
    with tf.Session() as sess:
        expected = sess.run(
            tf.image.resize_images(frame, [5, 4],
                                   method=tf.image.ResizeMethod.AREA))
    self.assertEqual(actual.shape, (5, 4, 3))
    self.assertTrue(np.allclose(actual, expected))
def test_logged_single_env():
    """
    Test LoggedEnv for a single environment.

    Four complete episodes are played and the log must hold
    one row per episode.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        csv_path = os.path.join(tmp_dir, 'monitor.csv')
        logged = LoggedEnv(SimpleEnv(2, (3, ), 'float32'), csv_path)
        for _ in range(4):
            logged.reset()
            # Step until the done flag (index 2 of the step result) is set.
            finished = False
            while not finished:
                finished = logged.step(logged.action_space.sample())[2]
        logged.close()
        # The handle itself is unused; pandas reads the file by path.
        with open(csv_path, 'rt'):
            table = pandas.read_csv(csv_path)
        assert list(table['r']) == [2] * 4
        assert list(table['l']) == [3] * 4
def test_ep_basic_equivalence(stateful, state_tuple, limits):
    """
    Test that EpisodeRoller is equivalent to a BasicRoller
    when run on a single environment.

    `limits` is forwarded as keyword arguments to both rollers.
    """
    def make_env():
        return SimpleEnv(3, (4, 5), 'uint8')

    single_env = make_env()
    model = SimpleModel(single_env.action_space.low.shape,
                        stateful=stateful,
                        state_tuple=state_tuple)
    reference = BasicRoller(single_env, model, **limits).rollouts()

    episode_roller = EpisodeRoller(batched_gym_env([make_env], sync=True),
                                   model,
                                   **limits)
    produced = episode_roller.rollouts()
    _compare_rollout_batch(produced, reference)
def _test_basic_equivalence_case(self, stateful, state_tuple, **roller_kwargs):
    """
    Test BasicRoller equivalence for a single env in a
    specific case.

    `roller_kwargs` is forwarded unchanged to both the
    BasicRoller and the EpisodeRoller.
    """
    def make_env():
        return SimpleEnv(3, (4, 5), 'uint8')

    single_env = make_env()
    model = SimpleModel(single_env.action_space.low.shape,
                        stateful=stateful,
                        state_tuple=state_tuple)
    reference = BasicRoller(single_env, model, **roller_kwargs).rollouts()

    episode_roller = EpisodeRoller(batched_gym_env([make_env], sync=True),
                                   model,
                                   **roller_kwargs)
    produced = episode_roller.rollouts()
    _compare_rollout_batch(self, produced, reference)
def _test_basic_equivalence_case(self, stateful, state_tuple):
    """
    Test BasicRoller equivalence for a specific set of model
    settings.

    The TruncatedRoller is given exactly as many timesteps as
    the five reference episodes contain.
    """
    def make_env():
        return SimpleEnv(3, (4, 5), 'uint8')

    single_env = make_env()
    model = SimpleModel(single_env.action_space.low.shape,
                        stateful=stateful,
                        state_tuple=state_tuple)
    reference = BasicRoller(single_env, model, min_episodes=5).rollouts()
    step_count = sum(r.num_steps for r in reference)

    roller = TruncatedRoller(batched_gym_env([make_env], sync=True),
                             model,
                             step_count)
    produced = roller.rollouts()
    _compare_rollout_batch(self, produced, reference)
def test_rewards(self):
    """
    Test that rewards are masked properly.

    With num_eps=5 and warmup_eps=-2, every reward seen while
    the first three sub-episodes finish must be zero, and
    every reward after that must be non-zero until the
    meta-episode ends.
    """
    real_env = SimpleEnv(1, (3, 4), 'uint8')
    env = RL2Env(real_env, real_env.action_space.sample(),
                 num_eps=5, warmup_eps=-2)
    env.reset()

    # Masked phase: obs[3] being truthy is counted as one finished
    # sub-episode.
    finished = 0
    while finished < 3:
        obs, reward, _, _ = env.step(env.action_space.sample())
        self.assertEqual(reward, 0)
        if obs[3]:
            finished += 1

    # Unmasked phase: runs until the wrapper reports done.
    done = False
    while not done:
        _, reward, done, _ = env.step(env.action_space.sample())
        self.assertNotEqual(reward, 0)
def test_single_env(self):
    """
    Test monitoring for a single environment.

    Four complete episodes are played and the log must hold
    one row per episode.
    """
    tmp_dir = tempfile.mkdtemp()
    try:
        csv_path = os.path.join(tmp_dir, 'monitor.csv')
        logged = LoggedEnv(SimpleEnv(2, (3,), 'float32'), csv_path)
        for _ in range(4):
            logged.reset()
            # Step until the done flag (index 2 of the step result) is set.
            finished = False
            while not finished:
                finished = logged.step(logged.action_space.sample())[2]
        logged.close()
        # The handle itself is unused; pandas reads the file by path.
        with open(csv_path, 'rt'):
            table = pandas.read_csv(csv_path)
        self.assertEqual(list(table['r']), [2] * 4)
        self.assertEqual(list(table['l']), [3] * 4)
    finally:
        shutil.rmtree(tmp_dir)
def test_trunc_basic_equivalence(stateful, state_tuple):
    """
    Test that TruncatedRoller is equivalent to BasicRoller for
    batches of one environment when the episodes end cleanly.

    The TruncatedRoller is given exactly as many timesteps as
    the five reference episodes contain.
    """
    def make_env():
        return SimpleEnv(3, (4, 5), 'uint8')

    single_env = make_env()
    model = SimpleModel(single_env.action_space.low.shape,
                        stateful=stateful,
                        state_tuple=state_tuple)
    reference = BasicRoller(single_env, model, min_episodes=5).rollouts()
    step_count = sum(r.num_steps for r in reference)

    roller = TruncatedRoller(batched_gym_env([make_env], sync=True),
                             model,
                             step_count)
    produced = roller.rollouts()
    _compare_rollout_batch(produced, reference)
def test_num_eps(self):
    """
    Test that the meta-episode contains the right number of
    sub-episodes.

    The wrapper must report done on the step that finishes the
    third sub-episode and on no earlier step.
    """
    real_env = SimpleEnv(1, (3, 4), 'uint8')
    env = RL2Env(real_env, real_env.action_space.sample(),
                 num_eps=3, warmup_eps=1)
    finished = 0
    env.reset()
    while finished < 3:
        obs, _, done, _ = env.step(env.action_space.sample())
        # obs[3] being truthy is counted as one finished sub-episode.
        if obs[3]:
            finished += 1
        if finished < 3:
            self.assertFalse(done)
            continue
        self.assertTrue(done)
        break
def test_multiple_batches(self):
    """
    Make sure calling rollouts multiple times works.

    The first batch from an EpisodeRoller is compared against
    three further batches from the same roller.
    """
    def make_env():
        return SimpleEnv(3, (4, 5), 'uint8')

    # A throwaway environment is only needed for the action shape.
    probe_env = make_env()
    try:
        model = SimpleModel(probe_env.action_space.low.shape)
    finally:
        probe_env.close()

    batched_env = batched_gym_env([make_env], sync=True)
    try:
        roller = EpisodeRoller(batched_env, model, min_episodes=5, min_steps=7)
        baseline = roller.rollouts()
        for _ in range(3):
            repeat = roller.rollouts()
            _compare_rollout_batch(self, baseline, repeat)
    finally:
        batched_env.close()
def test_rl2_num_eps():
    """
    Test that RL^2 meta-episodes contain the right number of
    sub-episodes.

    The wrapper must report done on the step that finishes the
    third sub-episode and on no earlier step.
    """
    real_env = SimpleEnv(1, (3, 4), 'uint8')
    env = RL2Env(real_env, real_env.action_space.sample(),
                 num_eps=3, warmup_eps=1)
    finished = 0
    env.reset()
    while finished < 3:
        obs, _, done, _ = env.step(env.action_space.sample())
        # obs[3] being truthy is counted as one finished sub-episode.
        if obs[3]:
            finished += 1
        if finished < 3:
            assert not done
            continue
        assert done
        break
def test_max_2(self):
    """
    Test maxing 2 frames.

    Five raw frames (a reset plus four steps) are recorded,
    then the same actions are replayed through
    MaxEnv(num_images=2). The first wrapped frame must equal
    the first raw frame, and each later one the element-wise
    max of the current and previous raw frames.
    """
    env = SimpleEnv(5, (3, 2, 5), 'float32')
    actions = [env.action_space.sample() for _ in range(4)]

    raw_frames = [env.reset()]
    for action in actions:
        raw_frames.append(env.step(action)[0])

    wrapped = MaxEnv(env, num_images=2)
    maxed_frames = [wrapped.reset()]
    for action in actions:
        maxed_frames.append(wrapped.step(action)[0])

    self.assertTrue((maxed_frames[0] == raw_frames[0]).all())
    for idx in range(1, len(raw_frames)):
        pair_max = np.max([raw_frames[idx - 1], raw_frames[idx]], axis=0)
        self.assertTrue((maxed_frames[idx] == pair_max).all())
def test_ep_batches(stateful, state_tuple, limits):
    """
    Test that EpisodeRoller is equivalent to a BasicRoller
    when run on a batch of envs.

    The EpisodeRoller runs over 21 copies of one environment
    in seven sub-batches; `limits` is forwarded to it as
    keyword arguments.
    """
    def make_env():
        return SimpleEnv(3, (4, 5), 'uint8')

    model = SimpleModel((4, 5), stateful=stateful, state_tuple=state_tuple)
    batched_env = batched_gym_env([make_env] * 21, num_sub_batches=7, sync=True)
    episode_roller = EpisodeRoller(batched_env, model, **limits)
    produced = episode_roller.rollouts()

    step_total = sum(r.num_steps for r in produced)
    assert len(produced) >= episode_roller.min_episodes
    assert step_total >= episode_roller.min_steps

    if 'min_steps' not in limits:
        # With only an episode limit the exact rollout count is pinned.
        wanted_count = episode_roller.min_episodes + batched_env.num_envs - 1
        assert len(produced) == wanted_count

    reference = BasicRoller(make_env(), model, min_episodes=len(produced)).rollouts()
    _compare_rollout_batch(produced, reference)
def test_trunc_drop_states():
    """
    Test TruncatedRoller with drop_states=True.

    Rollouts from a default roller have their 'states' entries
    set to None by hand and are then compared with rollouts
    from a drop_states=True roller on the same batched env.
    """
    env_fns = [lambda seed=x: SimpleEnv(seed, (5, 3), 'uint8') for x in range(15)]
    model = SimpleModel((5, 3), stateful=True, state_tuple=True)
    batched_env = batched_gym_env(env_fns, num_sub_batches=3, sync=True)

    keeping_roller = TruncatedRoller(batched_env, model, 17)
    reference = [r for _ in range(3) for r in keeping_roller.rollouts()]
    # Blank out the states manually to build the expectation.
    for rollout in reference:
        for model_out in rollout.model_outs:
            model_out['states'] = None

    dropping_roller = TruncatedRoller(batched_env, model, 17, drop_states=True)
    produced = [r for _ in range(3) for r in dropping_roller.rollouts()]

    _compare_rollout_batch(produced, reference)
def test_trunc_batches(stateful, state_tuple):
    """
    Test that TruncatedRoller produces the same result for
    batches as it does for individual environments.

    Rollouts gathered from 15 single-environment rollers are
    compared (ignoring order) with rollouts from one roller
    over all 15 environments split into three sub-batches.
    """
    env_fns = [lambda seed=x: SimpleEnv(seed, (5, 3), 'uint8') for x in range(15)]
    model = SimpleModel((5, 3), stateful=stateful, state_tuple=state_tuple)

    # One roller per environment, three rollout calls each.
    one_at_a_time = []
    for make_env in env_fns:
        solo_roller = TruncatedRoller(batched_gym_env([make_env], sync=True), model, 17)
        for _ in range(3):
            one_at_a_time += solo_roller.rollouts()

    # A single roller over every environment at once.
    joint_roller = TruncatedRoller(
        batched_gym_env(env_fns, num_sub_batches=3, sync=True), model, 17)
    all_together = []
    for _ in range(3):
        all_together += joint_roller.rollouts()

    _compare_rollout_batch(one_at_a_time, all_together, ordered=False)
def env_fn():
    """
    Construct a fresh SimpleEnv(3, (4, 5), 'uint8').
    """
    fresh_env = SimpleEnv(3, (4, 5), 'uint8')
    return fresh_env
def make_fn(seed):
    """
    Get an environment constructor with a seed.

    The returned callable takes no arguments and builds
    SimpleEnv(seed, SHAPE, dtype); SHAPE and dtype are looked
    up from the enclosing scope when it is called.
    """
    def _construct():
        return SimpleEnv(seed, SHAPE, dtype)

    return _construct
def env_fn():
    """
    Construct a fresh SimpleEnv(7, (5, 3), 'uint8').
    """
    fresh_env = SimpleEnv(7, (5, 3), 'uint8')
    return fresh_env