import pickle

import numpy as np
import pytest

# Assumed garage-style import paths; adjust to your tree. The two buffer
# classes below come from different API generations and may not coexist in
# a single garage release.
from garage.replay_buffer import HerReplayBuffer, HERReplayBuffer
from tests.fixtures.envs.dummy import DummyDictEnv


class TestHerReplayBuffer:
    """Tests for the legacy HerReplayBuffer API."""

    def setup_method(self):
        self.env = DummyDictEnv()
        self.obs = self.env.reset()
        self.replay_buffer = HerReplayBuffer(
            env_spec=self.env.spec,
            size_in_transitions=3,
            time_horizon=1,
            replay_k=0.4,
            reward_fun=self.env.compute_reward)

    def _add_single_transition(self):
        self.replay_buffer.add_transition(
            observation=self.obs,
            action=self.env.action_space.sample(),
            terminal=False,
            next_observation=self.obs)

    def _add_transitions(self):
        self.replay_buffer.add_transitions(
            observation=[self.obs],
            action=[self.env.action_space.sample()],
            terminal=[False],
            next_observation=[self.obs])

    def test_add_transition_dtype(self):
        self._add_single_transition()
        sample = self.replay_buffer.sample(1)
        assert sample['observation'].dtype == self.env.observation_space[
            'observation'].dtype
        assert sample['achieved_goal'].dtype == self.env.observation_space[
            'achieved_goal'].dtype
        assert sample['goal'].dtype == self.env.observation_space[
            'desired_goal'].dtype
        assert sample['action'].dtype == self.env.action_space.dtype

    def test_add_transitions_dtype(self):
        self._add_transitions()
        sample = self.replay_buffer.sample(1)
        assert sample['observation'].dtype == self.env.observation_space[
            'observation'].dtype
        assert sample['achieved_goal'].dtype == self.env.observation_space[
            'achieved_goal'].dtype
        assert sample['goal'].dtype == self.env.observation_space[
            'desired_goal'].dtype
        assert sample['action'].dtype == self.env.action_space.dtype

    def test_eviction_policy(self):
        # The buffer holds 3 transitions; once full, further inserts
        # overwrite the oldest slots in circular order.
        self.replay_buffer.add_transitions(
            observation=[self.obs, self.obs],
            next_observation=[self.obs, self.obs],
            terminal=[False, False],
            action=[1, 2])
        assert not self.replay_buffer.full

        self.replay_buffer.add_transitions(
            observation=[self.obs, self.obs],
            next_observation=[self.obs, self.obs],
            terminal=[False, False],
            action=[3, 4])
        assert self.replay_buffer.full

        self.replay_buffer.add_transitions(
            observation=[self.obs, self.obs],
            next_observation=[self.obs, self.obs],
            terminal=[False, False],
            action=[5, 6])
        self.replay_buffer.add_transitions(
            observation=[self.obs, self.obs],
            next_observation=[self.obs, self.obs],
            terminal=[False, False],
            action=[7, 8])

        assert np.array_equal(self.replay_buffer._buffer['action'],
                              [[7], [8], [6]])
        assert self.replay_buffer.n_transitions_stored == 3

    def test_pickleable(self):
        self._add_transitions()
        replay_buffer_pickled = pickle.loads(pickle.dumps(self.replay_buffer))
        assert (replay_buffer_pickled._buffer.keys() ==
                self.replay_buffer._buffer.keys())
        for k in replay_buffer_pickled._buffer:
            assert (replay_buffer_pickled._buffer[k].shape ==
                    self.replay_buffer._buffer[k].shape)
        sample = self.replay_buffer.sample(1)
        sample2 = replay_buffer_pickled.sample(1)
        for k in self.replay_buffer._buffer:
            assert sample[k].shape == sample2[k].shape
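
# A minimal sketch (hypothetical helper, not the buffer's implementation) of
# the circular eviction that test_eviction_policy asserts: write indices wrap
# at capacity, so inserting actions 1..8 into three slots leaves [7, 8, 6].
def _ring_insert(capacity, items):
    buf = [None] * capacity
    for i, item in enumerate(items):
        buf[i % capacity] = item  # overwrite the oldest slot once full
    return buf
# e.g. _ring_insert(3, range(1, 9)) == [7, 8, 6]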

class TestHERReplayBuffer:
    """Tests for the newer HERReplayBuffer API.

    Renamed from TestHerReplayBuffer to avoid clashing with the legacy test
    class above.
    """

    def setup_method(self):
        self.env = DummyDictEnv()
        self.obs = self.env.reset()
        self._replay_k = 4
        self.replay_buffer = HERReplayBuffer(
            env_spec=self.env.spec,
            capacity_in_transitions=10,
            replay_k=self._replay_k,
            reward_fn=self.env.compute_reward)

    def test_replay_k(self):
        # replay_k must be a non-negative integer: 0 is valid, 0.2 is not.
        self.replay_buffer = HERReplayBuffer(
            env_spec=self.env.spec,
            capacity_in_transitions=10,
            replay_k=0,
            reward_fn=self.env.compute_reward)
        with pytest.raises(ValueError):
            self.replay_buffer = HERReplayBuffer(
                env_spec=self.env.spec,
                capacity_in_transitions=10,
                replay_k=0.2,
                reward_fn=self.env.compute_reward)

    def _add_one_path(self):
        path = dict(
            observations=np.asarray([self.obs, self.obs]),
            actions=np.asarray([
                self.env.action_space.sample(),
                self.env.action_space.sample()
            ]),
            rewards=np.asarray([[1], [1]]),
            terminals=np.asarray([[False], [False]]),
            next_observations=np.asarray([self.obs, self.obs]),
        )
        self.replay_buffer.add_path(path)

    def test_add_path(self):
        self._add_one_path()

        # The HER buffer should add replay_k + 1 transitions to the buffer
        # for each transition in the given path, except for the last
        # transition, for which only the original transition is added.
        path_len = 2
        total_expected_transitions = (path_len - 1) * (self._replay_k + 1) + 1
        assert (self.replay_buffer.n_transitions_stored ==
                total_expected_transitions)
        assert (len(self.replay_buffer._path_segments) ==
                total_expected_transitions - 1)

        # Check that the buffer has the correct keys.
        assert {
            'observations', 'next_observations', 'actions', 'rewards',
            'terminals'
        } <= set(self.replay_buffer._buffer)

        # Check that dict observations are flattened. Note that .shape is a
        # tuple, so it must be compared against (flat_dim,), not flat_dim.
        obs = self.replay_buffer._buffer['observations'][0]
        next_obs = self.replay_buffer._buffer['next_observations'][0]
        assert obs.shape == (self.env.spec.observation_space.flat_dim, )
        assert next_obs.shape == (self.env.spec.observation_space.flat_dim, )

    def test_pickleable(self):
        self._add_one_path()
        replay_buffer_pickled = pickle.loads(pickle.dumps(self.replay_buffer))
        assert (replay_buffer_pickled._buffer.keys() ==
                self.replay_buffer._buffer.keys())
        for k in replay_buffer_pickled._buffer:
            assert (replay_buffer_pickled._buffer[k].shape ==
                    self.replay_buffer._buffer[k].shape)
        sample = self.replay_buffer.sample_transitions(1)
        sample2 = replay_buffer_pickled.sample_transitions(1)
        for k in sample.keys():
            assert sample[k].shape == sample2[k].shape
        assert len(sample) == len(sample2)
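
# A minimal sketch of HER 'future'-strategy relabeling (hypothetical helper
# with an assumed reward_fn(achieved, goal) signature, not garage's
# implementation), illustrating the arithmetic in test_add_path: every step
# except the last is stored once with its original goal plus replay_k
# relabeled copies, so a path of length T yields
# (T - 1) * (replay_k + 1) + 1 transitions.
def _her_relabel(path, replay_k, reward_fn):
    """path: list of dicts with 'achieved_goal' and 'desired_goal' keys."""
    out = []
    for t, transition in enumerate(path):
        out.append(transition)  # the original transition is always kept
        if t == len(path) - 1:
            break  # the last step has no future achieved goals to borrow
        for _ in range(replay_k):
            # Substitute a goal actually achieved at a later step and
            # recompute the reward against it.
            future = np.random.randint(t + 1, len(path))
            goal = path[future]['achieved_goal']
            relabeled = dict(transition,
                             desired_goal=goal,
                             reward=reward_fn(transition['achieved_goal'],
                                              goal))
            out.append(relabeled)
    return out
# With T == 2 and replay_k == 4, this yields 1 * 5 + 1 == 6 transitions,
# matching total_expected_transitions in test_add_path above.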