def test_from_time_step_list_batch(batch_data):
    """Check that two copies of a batch are merged by concatenation.

    ``TimeStepBatch.from_time_step_list`` is called with a list holding
    ``batch_data`` twice. Every array field and every entry of the two
    info dicts on the result must equal the ``np.concatenate`` of the
    corresponding inputs, and ``env_spec`` must be passed through.

    Args:
        batch_data: Mapping of test data, indexed here by ``'env_spec'``,
            the array field names and ``'env_infos'`` / ``'agent_infos'``.
            Presumably supplied by a pytest fixture -- confirm in conftest.

    """
    batches = [batch_data, batch_data]
    s = TimeStepBatch.from_time_step_list(batch_data['env_spec'], batches)

    array_fields = (
        'observations',
        'next_observations',
        'actions',
        'rewards',
        'step_types',
    )
    info_groups = ('env_infos', 'agent_infos')

    # Build every expectation up front, before any assertion runs.
    expected_arrays = {
        field: np.concatenate([batch[field] for batch in batches])
        for field in array_fields
    }
    expected_infos = {
        group: {
            key: np.concatenate([batch[group][key] for batch in batches])
            for key in batches[0][group]
        }
        for group in info_groups
    }

    assert s.env_spec == batch_data['env_spec']

    for field, expected in expected_arrays.items():
        assert np.array_equal(getattr(s, field), expected)

    for group, expected_group in expected_infos.items():
        for key, expected in expected_group.items():
            assert key in getattr(s, group)
            assert np.array_equal(expected, getattr(s, group)[key])
def test_from_empty_time_step_list_batch(batch_data):
    """Check that an empty list of batches is rejected.

    ``TimeStepBatch.from_time_step_list`` is expected to raise a
    ``ValueError`` whose message matches ``'at least one dict'`` when it
    is given an empty list.

    Args:
        batch_data: Mapping of test data; only its ``'env_spec'`` entry
            is used. Presumably supplied by a pytest fixture -- confirm
            in conftest.

    """
    # Do the setup outside the ``raises`` context so that only the call
    # under test can satisfy the expectation.
    env_spec = batch_data['env_spec']
    batches = []

    with pytest.raises(ValueError, match='at least one dict'):
        TimeStepBatch.from_time_step_list(env_spec, batches)