# NOTE(review): a second class named ``TestHierReplayBuffer`` is defined later
# in this file and rebinds the name, so loaders that discover test classes by
# module attribute will only see that later definition -- confirm which of the
# two is meant to run.
class TestHierReplayBuffer(unittest.TestCase):
    """Unit tests for a three-level HierReplayBuffer."""

    def setUp(self):
        # Every dimension is 1 so the sampled values are easy to read.
        buffer_kwargs = dict(
            buffer_size=2,
            batch_size=1,
            meta_period=3,
            obs_dim=1,
            ac_dim=1,
            co_dim=1,
            goal_dim=1,
            num_levels=3,
        )
        self.replay_buffer = HierReplayBuffer(**buffer_kwargs)

    def tearDown(self):
        del self.replay_buffer

    def test_buffer_size(self):
        """Check that `buffer_size` matches the constructor argument."""
        self.assertEqual(self.replay_buffer.buffer_size, 2)

    def test_add_sample(self):
        """Exercise `add`, `is_full`, `can_sample` and `sample` in sequence."""
        # Seed Python's RNG first; the expected samples below presumably
        # depend on it -- TODO confirm against the buffer implementation.
        random.seed(0)

        def column(count):
            # [array([0]), array([1]), ..., array([count - 1])]
            return [np.array([idx]) for idx in range(count)]

        # The same payload is stored twice, exactly filling the buffer.
        sample_kwargs = dict(
            obs_t=column(10),
            action_t=[column(4), column(10), column(10)],
            context_t=column(2),
            reward_t=[[0], [0, 1, 2], list(range(9))],
            done_t=[False] * 9,
        )
        buf = self.replay_buffer

        # One stored element out of two: not full yet.
        buf.add(**sample_kwargs)
        self.assertEqual(buf.is_full(), False)

        # Second element: the buffer is full and can be sampled.
        buf.add(**sample_kwargs)
        self.assertEqual(buf.is_full(), True)
        self.assertEqual(buf.can_sample(), True)

        # Validate the sampled batch, field by field and level by level.
        obs0, obs1, act, rew, done, _ = buf.sample(False)
        expectations = (
            (obs0, ([[0, 0]], [[6, 2]], [[6, 6]])),
            (obs1, ([[9, 1]], [[9, 3]], [[7, 7]])),
            (act, ([[0]], [[6]], [[6]])),
            (rew, ([0], [2], [6])),
            (done, ([0], [0], [0])),
        )
        check = np.testing.assert_array_almost_equal
        for actual, wanted in expectations:
            for level, value in enumerate(wanted):
                check(actual[level], value)
# NOTE(review): an earlier class in this file is also named
# ``TestHierReplayBuffer``; this definition rebinds the name, so the earlier
# one is no longer reachable by name at module level -- confirm that is
# intended.
class TestHierReplayBuffer(unittest.TestCase):
    """Tests for the HierReplayBuffer object (meta/worker variant)."""

    def setUp(self):
        # Distinct dimensions (2, 3, 4, 5) make a mixed-up shape easy to spot.
        self.replay_buffer = HierReplayBuffer(
            buffer_size=2,
            batch_size=1,
            meta_period=1,
            meta_obs_dim=2,
            meta_ac_dim=3,
            worker_obs_dim=4,
            worker_ac_dim=5,
        )

    def tearDown(self):
        del self.replay_buffer

    def test_init(self):
        """Validate that all the attributes were initialized properly."""
        self.assertTupleEqual(self.replay_buffer.meta_obs0.shape, (1, 2))
        self.assertTupleEqual(self.replay_buffer.meta_obs1.shape, (1, 2))
        self.assertTupleEqual(self.replay_buffer.meta_act.shape, (1, 3))
        self.assertTupleEqual(self.replay_buffer.meta_rew.shape, (1, ))
        self.assertTupleEqual(self.replay_buffer.meta_done.shape, (1, ))
        self.assertTupleEqual(self.replay_buffer.worker_obs0.shape, (1, 4))
        self.assertTupleEqual(self.replay_buffer.worker_obs1.shape, (1, 4))
        self.assertTupleEqual(self.replay_buffer.worker_act.shape, (1, 5))
        self.assertTupleEqual(self.replay_buffer.worker_rew.shape, (1, ))
        self.assertTupleEqual(self.replay_buffer.worker_done.shape, (1, ))

    def test_buffer_size(self):
        """Validate the buffer_size output from the replay buffer."""
        self.assertEqual(self.replay_buffer.buffer_size, 2)

    def test_add_sample(self):
        """Test the `add` and `sample` methods of the replay buffer."""
        # Add an element.
        self.replay_buffer.add(
            obs_t=[np.array([0, 0, 0, 0]), np.array([1, 1, 1, 1])],
            goal_t=np.array([2, 2, 2]),
            action_t=[np.array([3, 3, 3, 3, 3])],
            reward_t=[4],
            done=[False],
            meta_obs_t=(np.array([5, 5]), np.array([6, 6])),
            meta_reward_t=7,
        )

        # Check is_full in the False case.
        self.assertEqual(self.replay_buffer.is_full(), False)

        # Add an element.
        self.replay_buffer.add(
            obs_t=[np.array([0, 0, 0, 0]), np.array([1, 1, 1, 1])],
            goal_t=np.array([2, 2, 2]),
            action_t=[np.array([3, 3, 3, 3, 3])],
            reward_t=[4],
            done=[False],
            meta_obs_t=(np.array([5, 5]), np.array([6, 6])),
            meta_reward_t=7,
        )

        # Check is_full in the True case.
        self.assertEqual(self.replay_buffer.is_full(), True)

        # Check can_sample in the True case.
        self.assertEqual(self.replay_buffer.can_sample(), True)

        # Test the `sample` method. The trailing element of the returned
        # tuple is not checked here.
        (meta_obs0, meta_obs1, meta_act, meta_rew, meta_done,
         worker_obs0, worker_obs1, worker_act, worker_rew, worker_done,
         _) = self.replay_buffer.sample()

        # Meta-level fields: `goal_t` is expected back as the meta action.
        np.testing.assert_array_almost_equal(meta_obs0, [[5, 5]])
        np.testing.assert_array_almost_equal(meta_obs1, [[6, 6]])
        np.testing.assert_array_almost_equal(meta_act, [[2, 2, 2]])
        np.testing.assert_array_almost_equal(meta_rew, [7])
        np.testing.assert_array_almost_equal(meta_done, [0])

        # Worker-level fields.
        np.testing.assert_array_almost_equal(worker_obs0, [[0, 0, 0, 0]])
        np.testing.assert_array_almost_equal(worker_obs1, [[1, 1, 1, 1]])
        np.testing.assert_array_almost_equal(worker_act, [[3, 3, 3, 3, 3]])
        np.testing.assert_array_almost_equal(worker_rew, [4])
        np.testing.assert_array_almost_equal(worker_done, [0])