def test_mixed_batch():
    """
    Check that a BatchedPlayer over several environments
    yields the same collection of episodes as a set of
    BasicPlayers, one per environment.

    The comparison ignores episode order: every expected
    episode must be matched by a distinct actual episode.
    The check is repeated for 1, 2, and 3 sub-batches.
    """
    # NOTE: all seeds are identical right now; the varied
    # list [5, 8, 1, 9, 3, 2] is kept here for reference.
    seeds = [3, 3, 3, 3, 3, 3]
    env_fns = [lambda s=seed: SimpleEnv(s, (1, 2, 3), 'float32')
               for seed in seeds]

    def make_agent():
        return SimpleModel((1, 2, 3), stateful=True)

    for num_sub in (1, 2, 3):
        # Build the batched player before the per-environment
        # players, then build every basic player before any
        # of them is asked to play.
        batched_env = batched_gym_env(env_fns, num_sub_batches=num_sub)
        batched_player = BatchedPlayer(batched_env, make_agent(), 3)
        basic_players = [BasicPlayer(make_env(), make_agent(), 3)
                         for make_env in env_fns]

        expected_eps = []
        for basic_player in basic_players:
            transitions = []
            for _ in range(50):
                transitions.extend(basic_player.play())
            expected_eps.extend(_separate_episodes(transitions))

        batched_transitions = []
        for _ in range(50):
            batched_transitions.extend(batched_player.play())
        actual_eps = _separate_episodes(batched_transitions)

        assert len(expected_eps) == len(actual_eps)

        for wanted in expected_eps:
            # Index of the first not-yet-consumed actual episode
            # equivalent to this one, or None if there is none.
            match_idx = next(
                (idx for idx, candidate in enumerate(actual_eps)
                 if _episodes_equivalent(wanted, candidate)),
                None,
            )
            assert match_idx is not None
            # Consume the match so it cannot pair with another
            # expected episode.
            del actual_eps[match_idx]
def test_single_batch():
    """
    Check BatchedPlayer against BasicPlayer when the
    batched environment wraps a single environment.

    Both players are stepped 50 times; each call to play()
    must return the same number of transitions, and the
    transitions must match pairwise.
    """
    def make_env():
        return SimpleEnv(9, (1, 2, 3), 'float32')

    def make_agent():
        return SimpleModel((1, 2, 3), stateful=True)

    basic_player = BasicPlayer(make_env(), make_agent(), 3)
    batched_player = BatchedPlayer(batched_gym_env([make_env]),
                                   make_agent(), 3)

    for _ in range(50):
        basic_out = basic_player.play()
        batched_out = batched_player.play()
        assert len(basic_out) == len(batched_out)
        # Stops at the first mismatching pair, like a per-pair assert.
        assert all(_transitions_equal(lhs, rhs)
                   for lhs, rhs in zip(basic_out, batched_out))