Пример #1
0
    def _check(num_transitions, batch_size, permute):
        """Drive a BootstrapIterator and verify what each ensemble member drew.

        Observations are set to their own row index, so the values each member
        receives can be read back as the indices it sampled. ``num_members`` is
        read from the enclosing scope.

        With ``permute`` true, every member must see each index exactly once,
        in shuffled order. Otherwise each member must see ``num_transitions``
        in-range indices that are not the identity ordering. In both modes no
        two members may end up with identical sequences.
        """
        zeros = np.zeros((num_transitions, 1))
        indexed_obs = np.arange(num_transitions)[:, None]
        source = TransitionBatch(indexed_obs, zeros, indexed_obs + 1, zeros, zeros)
        iterator = replay_buffer.BootstrapIterator(
            source, batch_size, num_members, permute_indices=permute
        )

        seen = [[] for _ in range(num_members)]
        for batch in iterator:
            member_obs, *_ = batch.astuple()
            # Leading axis indexes the ensemble member; the trailing axis is
            # the single observation feature.
            assert member_obs.shape[0] == num_members
            assert member_obs.shape[2] == 1
            for member_idx, collected in enumerate(seen):
                collected.extend(member_obs[member_idx].squeeze(1).tolist())

        expected = list(range(num_transitions))
        for idx, contents in enumerate(seen):
            if permute:
                # Same multiset of indices as the input, but reordered.
                assert sorted(contents) == expected
                assert contents != expected
            else:
                # Sampled with replacement: right count, indices in range.
                assert len(contents) == num_transitions
                assert min(contents) >= 0
                assert max(contents) < num_transitions
                assert contents != expected

            # Each member's draw must differ from every later member's.
            for other in seen[idx + 1 :]:
                assert contents != other
Пример #2
0
    def _check(num_transitions, batch_size):
        """Check that a shuffling TransitionIterator covers every transition
        once per epoch and yields a different order on the next epoch.

        Observations are set to their own row index (0..num_transitions-1),
        so the order the iterator produces can be read back directly.
        """
        dummy = np.zeros((num_transitions, 1))
        input_obs = np.arange(num_transitions)[:, None]
        transitions = TransitionBatch(input_obs, dummy, dummy, dummy, dummy)
        it = replay_buffer.TransitionIterator(
            transitions, batch_size, shuffle_each_epoch=True
        )
        expected = list(range(num_transitions))

        def _collect_epoch():
            # Flatten one full pass over the iterator into a list of indices.
            collected = []
            for batch in it:
                obs, *_ = batch.astuple()
                collected.extend(row.item() for row in obs)
            return collected

        first_epoch = _collect_epoch()
        # Compare whole lists rather than zipping: zip() truncates, which would
        # let missing or extra transitions slip through unnoticed.
        assert sorted(first_epoch) == expected
        # The order must actually be shuffled.
        assert first_epoch != expected

        # The second epoch must again cover everything, in a different order.
        second_epoch = _collect_epoch()
        assert sorted(second_epoch) == expected
        assert second_epoch != first_epoch
Пример #3
0
def test_transition_batch_getitem():
    """Check TransitionBatch indexing with ints, slices and index arrays.

    For every key, indexing the batch must give the same values as applying
    that key to each of the five source arrays separately.
    """
    how_many = 10
    obs = np.random.randn(how_many, 4)
    act = np.random.randn(how_many, 2)
    next_obs = np.random.randn(how_many, 4)
    rewards = np.random.randn(how_many, 1)
    dones = np.random.randn(how_many, 1)
    fields = (obs, act, next_obs, rewards, dones)

    transitions = TransitionBatch(obs, act, next_obs, rewards, dones)

    def _assert_matches(key):
        # Check all five fields, not only obs.
        got = transitions[key].astuple()
        assert len(got) == len(fields)
        for got_field, source in zip(got, fields):
            assert np.allclose(got_field, source[key])

    for i in range(how_many):
        _assert_matches(i)
        _assert_matches(slice(i, None))
        _assert_matches(slice(None, i))
        for j in range(i + 1, how_many):
            _assert_matches(slice(i, j))

    # Fancy indexing with random index arrays (repeats allowed) of every size
    # from 1 to how_many - 1; previously the size was fixed at 5.
    for sz in range(1, how_many):
        _assert_matches(np.random.choice(how_many, size=sz))
Пример #4
0
    def _batch_from_indices(self, indices: Sized) -> TransitionBatch:
        """Return a TransitionBatch built from the stored entries at ``indices``.

        The same ``indices`` object is applied to each stored field (obs,
        action, next_obs, reward, done), so the fields stay aligned.
        NOTE(review): presumably the fields are numpy arrays and ``indices``
        is an integer array/sequence for fancy indexing — confirm.
        """
        stored = (self.obs, self.action, self.next_obs, self.reward, self.done)
        return TransitionBatch(*(field[indices] for field in stored))
Пример #5
0
def _consolidate_batches(batches: Sequence[TransitionBatch]) -> TransitionBatch:
    """Stack batches into one TransitionBatch with a new leading axis.

    Output shapes and dtypes for ``obs``, ``act`` and ``next_obs`` come from
    the first batch. Rewards are stored as ``np.float32`` and dones as
    ``bool``.

    Args:
        batches: non-empty sequence of batches whose fields should all have
            the same shapes as those of ``batches[0]``.

    Returns:
        A TransitionBatch whose fields have shape
        ``(len(batches),) + field.shape``.

    Raises:
        IndexError: if ``batches`` is empty (there is no first batch to take
            shapes from).
    """
    len_batches = len(batches)
    b0 = batches[0]
    obs = np.empty((len_batches,) + b0.obs.shape, dtype=b0.obs.dtype)
    act = np.empty((len_batches,) + b0.act.shape, dtype=b0.act.dtype)
    # next_obs must use its own shape/dtype. Borrowing obs's silently casts,
    # or fails on assignment, whenever the two differ.
    next_obs = np.empty(
        (len_batches,) + b0.next_obs.shape, dtype=b0.next_obs.dtype
    )
    rewards = np.empty((len_batches,) + b0.rewards.shape, dtype=np.float32)
    dones = np.empty((len_batches,) + b0.dones.shape, dtype=bool)
    for i, b in enumerate(batches):
        obs[i] = b.obs
        act[i] = b.act
        next_obs[i] = b.next_obs
        rewards[i] = b.rewards
        dones[i] = b.dones
    return TransitionBatch(obs, act, next_obs, rewards, dones)
Пример #6
0
 def _check_for_capacity_and_batch_size(num_transitions, batch_size):
     """Check that a TransitionIterator yields every transition in order.

     Batches must be ``batch_size`` long, except a shorter trailing batch
     when ``num_transitions`` is not a multiple of ``batch_size``. Each
     ``next_obs`` must equal its ``obs`` + 1, and every transition must be
     visited exactly once.
     """
     dummy = np.zeros((num_transitions, 1))
     input_obs = np.arange(num_transitions)[:, None]
     transitions = TransitionBatch(input_obs, dummy, input_obs + 1, dummy, dummy)
     it = replay_buffer.TransitionIterator(transitions, batch_size)
     # Use this function's own batch_size, not a variable from the caller.
     assert len(it) == int(np.ceil(num_transitions / batch_size))
     num_full_batches = num_transitions // batch_size
     idx_check = 0
     for i, batch in enumerate(it):
         obs, _, next_obs, _, _ = batch.astuple()
         if i < num_full_batches:
             assert len(obs) == batch_size
         else:
             # Only the trailing partial batch may be shorter.
             assert len(obs) == num_transitions % batch_size
         for j, row in enumerate(obs):
             assert row.item() == input_obs[idx_check].item()
             assert next_obs[j].item() == row.item() + 1
             idx_check += 1
     # Every transition must have been yielded, none skipped.
     assert idx_check == num_transitions