def test_batch_replay_buffer(n_envs, n_steps, batch_size, maxlen):
    """Roll out a batched CartPole env into a BatchReplayBuffer and check
    buffer length, MDPDataset conversion, and sampled mini-batch shapes."""
    env = SyncBatchEnv([gym.make("CartPole-v0") for _ in range(n_envs)])
    buffer = BatchReplayBuffer(maxlen, env)

    # seed the rollout: first rewards/terminals are zero placeholders
    obs = env.reset()
    rews = np.zeros(n_envs)
    dones = np.zeros(n_envs)
    for _ in range(n_steps):
        acts = np.random.randint(env.action_space.n, size=n_envs)
        buffer.append(obs, acts, rews, dones)
        obs, rews, dones, _ = env.step(acts)

    # buffer is bounded by maxlen
    assert len(buffer) == maxlen

    # check static dataset conversion
    dataset = buffer.to_mdp_dataset()
    transitions = [t for episode in dataset for t in episode.transitions]
    assert len(transitions) >= len(buffer)

    # sampled batch carries per-field shapes of (batch_size, ...)
    obs_shape = env.observation_space.shape
    batch = buffer.sample(batch_size)
    assert len(batch) == batch_size
    assert batch.observations.shape == (batch_size,) + obs_shape
    assert batch.actions.shape == (batch_size,)
    assert batch.rewards.shape == (batch_size, 1)
    assert batch.next_observations.shape == (batch_size,) + obs_shape
    assert batch.next_actions.shape == (batch_size,)
    assert batch.next_rewards.shape == (batch_size, 1)
    assert batch.terminals.shape == (batch_size, 1)
    assert isinstance(batch.observations, np.ndarray)
    assert isinstance(batch.next_observations, np.ndarray)
def test_sync_batch_env_continuous(n_envs, n_steps):
    """Step a batched continuous-action env (Pendulum) and verify that
    observations, rewards, terminals, and infos are batched over n_envs."""
    env = SyncBatchEnv([gym.make("Pendulum-v0") for _ in range(n_envs)])
    obs_shape = env.observation_space.shape
    act_size = env.action_space.shape[0]

    obs = env.reset()
    assert obs.shape == (n_envs,) + obs_shape

    for _ in range(n_steps):
        # random continuous actions; only shapes matter for this test
        acts = np.random.random((n_envs, act_size))
        obs, rews, dones, infos = env.step(acts)
        assert obs.shape == (n_envs,) + obs_shape
        assert rews.shape == (n_envs,)
        assert dones.shape == (n_envs,)
        assert len(infos) == n_envs
def test_sync_batch_env_discrete(n_envs, n_steps):
    """Step a batched discrete-action env (CartPole) and verify that
    observations, rewards, terminals, and infos are batched over n_envs."""
    env = SyncBatchEnv([gym.make("CartPole-v0") for _ in range(n_envs)])
    obs_shape = env.observation_space.shape
    n_actions = env.action_space.n

    obs = env.reset()
    assert obs.shape == (n_envs,) + obs_shape

    for _ in range(n_steps):
        # one random discrete action per sub-environment
        acts = np.random.randint(n_actions, size=n_envs)
        obs, rews, dones, infos = env.step(acts)
        assert obs.shape == (n_envs,) + obs_shape
        assert rews.shape == (n_envs,)
        assert dones.shape == (n_envs,)
        assert len(infos) == n_envs