示例#1
0
class TestHierReplayBuffer(unittest.TestCase):
    """Unit tests for a three-level HierReplayBuffer."""

    def setUp(self):
        # Two slots only, so two `add` calls are enough to fill the buffer.
        self.replay_buffer = HierReplayBuffer(
            buffer_size=2,
            batch_size=1,
            meta_period=3,
            obs_dim=1,
            ac_dim=1,
            co_dim=1,
            goal_dim=1,
            num_levels=3,
        )

    def tearDown(self):
        del self.replay_buffer

    def test_buffer_size(self):
        """Check that the configured buffer_size is reported back."""
        self.assertEqual(self.replay_buffer.buffer_size, 2)

    def test_add_sample(self):
        """Exercise the `add` and `sample` methods of the replay buffer."""
        # Seed Python's RNG before anything is sampled so the drawn
        # transition is reproducible.
        random.seed(0)

        # Ten single-element observations: [0], [1], ..., [9].
        observations = [np.array([step]) for step in range(10)]
        # Per-level actions: 4 entries for level 0, 10 for levels 1 and 2.
        actions = [
            [np.array([step]) for step in range(4)],
            [np.array([step]) for step in range(10)],
            [np.array([step]) for step in range(10)],
        ]
        contexts = [np.array([0]), np.array([1])]
        # Per-level rewards: [0], [0, 1, 2] and [0, ..., 8].
        rewards = [list(range(count)) for count in (1, 3, 9)]
        dones = [False] * 9

        # The very same objects are handed to both `add` calls.
        transition = dict(
            obs_t=observations,
            action_t=actions,
            context_t=contexts,
            reward_t=rewards,
            done_t=dones,
        )

        # After one insertion a slot is still free.
        self.replay_buffer.add(**transition)
        self.assertEqual(self.replay_buffer.is_full(), False)

        # The second insertion fills the buffer and enables sampling.
        self.replay_buffer.add(**transition)
        self.assertEqual(self.replay_buffer.is_full(), True)
        self.assertEqual(self.replay_buffer.can_sample(), True)

        obs0, obs1, act, rew, done, _ = self.replay_buffer.sample(False)

        # Expected values for each returned quantity, listed per level
        # (level 0, level 1, level 2) in the order they are checked.
        expectations = [
            (obs0, [[[0, 0]], [[6, 2]], [[6, 6]]]),
            (obs1, [[[9, 1]], [[9, 3]], [[7, 7]]]),
            (act, [[[0]], [[6]], [[6]]]),
            (rew, [[0], [2], [6]]),
            (done, [[0], [0], [0]]),
        ]
        for sampled, per_level in expectations:
            for level, expected in enumerate(per_level):
                np.testing.assert_array_almost_equal(sampled[level], expected)
class TestHierReplayBuffer(unittest.TestCase):
    """Tests for the HierReplayBuffer object."""

    def setUp(self):
        self.replay_buffer = HierReplayBuffer(buffer_size=2,
                                              batch_size=1,
                                              meta_period=1,
                                              meta_obs_dim=2,
                                              meta_ac_dim=3,
                                              worker_obs_dim=4,
                                              worker_ac_dim=5)

    def tearDown(self):
        del self.replay_buffer

    def test_init(self):
        """Validate that all the attributes were initialized properly."""
        # NOTE(review): the leading dimension of 1 appears to track
        # batch_size=1 rather than buffer_size=2 -- confirm against
        # HierReplayBuffer.__init__.
        self.assertTupleEqual(self.replay_buffer.meta_obs0.shape, (1, 2))
        self.assertTupleEqual(self.replay_buffer.meta_obs1.shape, (1, 2))
        self.assertTupleEqual(self.replay_buffer.meta_act.shape, (1, 3))
        self.assertTupleEqual(self.replay_buffer.meta_rew.shape, (1, ))
        self.assertTupleEqual(self.replay_buffer.meta_done.shape, (1, ))
        self.assertTupleEqual(self.replay_buffer.worker_obs0.shape, (1, 4))
        self.assertTupleEqual(self.replay_buffer.worker_obs1.shape, (1, 4))
        self.assertTupleEqual(self.replay_buffer.worker_act.shape, (1, 5))
        self.assertTupleEqual(self.replay_buffer.worker_rew.shape, (1, ))
        self.assertTupleEqual(self.replay_buffer.worker_done.shape, (1, ))

    def test_buffer_size(self):
        """Validate the buffer_size output from the replay buffer."""
        self.assertEqual(self.replay_buffer.buffer_size, 2)

    def _add_fixed_sample(self):
        """Add one fixed hierarchical transition to the replay buffer.

        Every call passes identical values. That is why the expectations
        in `test_add_sample` should not depend on which stored entry
        `sample` returns.
        """
        self.replay_buffer.add(
            obs_t=[np.array([0, 0, 0, 0]),
                   np.array([1, 1, 1, 1])],
            goal_t=np.array([2, 2, 2]),
            action_t=[np.array([3, 3, 3, 3, 3])],
            reward_t=[4],
            done=[False],
            meta_obs_t=(np.array([5, 5]), np.array([6, 6])),
            meta_reward_t=7,
        )

    def test_add_sample(self):
        """Test the `add` and `sample` methods of the replay buffer."""
        # Add an element; one of the two slots is still empty.
        self._add_fixed_sample()

        # Check is_full in the False case.
        self.assertEqual(self.replay_buffer.is_full(), False)

        # Add a second element; both slots are now occupied.
        self._add_fixed_sample()

        # Check is_full in the True case.
        self.assertEqual(self.replay_buffer.is_full(), True)

        # Check can_sample in the True case.
        self.assertEqual(self.replay_buffer.can_sample(), True)

        # Test the `sample` method.
        (meta_obs0, meta_obs1, meta_act, meta_rew, meta_done,
         worker_obs0, worker_obs1, worker_act, worker_rew, worker_done,
         _) = self.replay_buffer.sample()

        # Meta-level fields: the expected meta action equals the `goal_t`
        # passed to `add`, and the reward equals `meta_reward_t`.
        np.testing.assert_array_almost_equal(meta_obs0, [[5, 5]])
        np.testing.assert_array_almost_equal(meta_obs1, [[6, 6]])
        np.testing.assert_array_almost_equal(meta_act, [[2, 2, 2]])
        np.testing.assert_array_almost_equal(meta_rew, [7])
        np.testing.assert_array_almost_equal(meta_done, [0])

        # Worker-level fields mirror `obs_t`, `action_t`, `reward_t`
        # and `done`.
        np.testing.assert_array_almost_equal(worker_obs0, [[0, 0, 0, 0]])
        np.testing.assert_array_almost_equal(worker_obs1, [[1, 1, 1, 1]])
        np.testing.assert_array_almost_equal(worker_act, [[3, 3, 3, 3, 3]])
        np.testing.assert_array_almost_equal(worker_rew, [4])
        np.testing.assert_array_almost_equal(worker_done, [0])