def _check(num_transitions, batch_size, permute):
    """Validate ``replay_buffer.BootstrapIterator`` for one configuration.

    Observations are the ids ``0..num_transitions-1`` and next observations are
    ``id + 1``. The iterator is drained and each ensemble member's observations
    are gathered, then the following is asserted per member:

    * ``permute=True``: the member saw every id exactly once, not in storage
      order.
    * ``permute=False``: the member saw ``num_transitions`` in-range ids whose
      sequence differs from the identity order.

    In both modes, no two members may have seen identical sequences.

    NOTE(review): ``num_members`` is not a parameter; it is resolved from an
    enclosing scope (presumably the surrounding test) — confirm there.

    Args:
        num_transitions: number of transitions in the source batch.
        batch_size: batch size passed to the iterator.
        permute: forwarded as ``permute_indices``.
    """
    filler = np.zeros((num_transitions, 1))
    observations = np.arange(num_transitions)[:, None]
    iterator = replay_buffer.BootstrapIterator(
        TransitionBatch(observations, filler, observations + 1, filler, filler),
        batch_size,
        num_members,
        permute_indices=permute,
    )

    seen = [[] for _ in range(num_members)]
    for minibatch in iterator:
        batch_obs, *_ = minibatch.astuple()
        # Leading axis indexes ensemble members; trailing axis is the obs dim.
        assert batch_obs.shape[0] == num_members
        assert batch_obs.shape[2] == 1
        for member_seen, member_obs in zip(seen, batch_obs):
            member_seen.extend(member_obs.squeeze(1).tolist())

    expected = list(range(num_transitions))
    for idx, contents in enumerate(seen):
        if not permute:
            # Sampling with replacement: correct count and every id in range.
            assert len(contents) == num_transitions
            assert min(contents) >= 0
            assert max(contents) < num_transitions
        else:
            # Permutation: every id present exactly once.
            assert sorted(contents) == expected
        # Either way the member must not see the plain storage order.
        assert contents != expected
        # Each pair of members must have drawn different sequences.
        assert all(contents != other for other in seen[idx + 1:])
def _check(num_transitions, batch_size):
    """Check that ``TransitionIterator(shuffle_each_epoch=True)`` reshuffles.

    Builds a batch whose observations are ``0..num_transitions-1`` and iterates
    it for two epochs. Each epoch must visit every transition exactly once in a
    non-sorted order, and the two epochs must produce different orders.

    Args:
        num_transitions: number of transitions in the batch. It should be large
            enough that a sorted or repeated order is unlikely by chance.
        batch_size: batch size passed to the iterator.
    """
    dummy = np.zeros((num_transitions, 1))
    input_obs = np.arange(num_transitions)[:, None]
    transitions = TransitionBatch(input_obs, dummy, dummy, dummy, dummy)
    it = replay_buffer.TransitionIterator(
        transitions, batch_size, shuffle_each_epoch=True
    )
    all_elements = list(range(num_transitions))

    def _collect_epoch():
        # One full pass over the iterator, flattened to a list of observation ids.
        seen = []
        for batch in it:
            obs, *_ = batch.astuple()
            seen.extend(o.item() for o in obs)
        return seen

    all_obs = _collect_epoch()
    # Compare whole lists: a zip()-based check stops at the shorter input and
    # would miss transitions the iterator failed to yield.
    assert sorted(all_obs) == all_elements
    assert all_obs != all_elements  # shuffled, not in storage order

    # The second epoch must again cover everything, but in a different order.
    all_obs_second = _collect_epoch()
    assert sorted(all_obs_second) == all_elements
    assert all_obs_second != all_obs
def test_transition_batch_getitem():
    """``TransitionBatch[key]`` must match indexing each field array with ``key``.

    Covers integer indices, prefix/suffix slices, every ``[i:j]`` slice, and
    fancy (integer-array) indexing with index arrays of size 1..how_many-1.
    """
    how_many = 10
    obs = np.random.randn(how_many, 4)
    act = np.random.randn(how_many, 2)
    next_obs = np.random.randn(how_many, 4)
    rewards = np.random.randn(how_many, 1)
    dones = np.random.randn(how_many, 1)
    transitions = TransitionBatch(obs, act, next_obs, rewards, dones)

    def _assert_matches(key):
        # Every field of the indexed batch must equal the same field indexed directly.
        o, a, no, r, d = transitions[key].astuple()
        assert np.allclose(o, obs[key])
        assert np.allclose(a, act[key])
        assert np.allclose(no, next_obs[key])
        assert np.allclose(r, rewards[key])
        assert np.allclose(d, dones[key])

    for i in range(how_many):
        _assert_matches(i)
        _assert_matches(slice(i, None))
        _assert_matches(slice(None, i))
        for j in range(i + 1, how_many):
            _assert_matches(slice(i, j))

    # Previously ``size=5`` was hard-coded, so the loop variable was unused and
    # only one index-array length was ever exercised.
    for sz in range(1, how_many):
        indices = np.random.choice(how_many, size=sz)
        _assert_matches(indices)
def _batch_from_indices(self, indices: Sized) -> TransitionBatch:
    """Return a ``TransitionBatch`` of the stored transitions selected by ``indices``.

    Each of ``obs``, ``action``, ``next_obs``, ``reward`` and ``done`` is indexed
    with the same ``indices``, and the results are packed in
    ``TransitionBatch`` argument order.
    """
    selected = [
        field[indices]
        for field in (self.obs, self.action, self.next_obs, self.reward, self.done)
    ]
    obs, action, next_obs, reward, done = selected
    return TransitionBatch(obs, action, next_obs, reward, done)
def _consolidate_batches(batches: Sequence[TransitionBatch]) -> TransitionBatch:
    """Stack a sequence of equally-shaped batches along a new leading axis.

    The result's fields have shape ``(len(batches),) + field.shape``, sized
    from the first batch. ``obs``, ``act`` and ``next_obs`` keep the first
    batch's dtypes. ``rewards`` are stored as ``float32`` and ``dones`` as
    ``bool``.

    Args:
        batches: non-empty sequence of batches. All of them are assumed to have
            the same per-field shapes as ``batches[0]``; a mismatch makes numpy
            raise on assignment.

    Returns:
        A single ``TransitionBatch`` holding the stacked fields.
    """
    len_batches = len(batches)
    b0 = batches[0]
    obs = np.empty((len_batches,) + b0.obs.shape, dtype=b0.obs.dtype)
    act = np.empty((len_batches,) + b0.act.shape, dtype=b0.act.dtype)
    # Allocate from next_obs itself (not obs) so its shape/dtype are preserved
    # rather than silently broadcast or cast to the observation layout.
    next_obs = np.empty(
        (len_batches,) + b0.next_obs.shape, dtype=b0.next_obs.dtype
    )
    rewards = np.empty((len_batches,) + b0.rewards.shape, dtype=np.float32)
    dones = np.empty((len_batches,) + b0.dones.shape, dtype=bool)
    for i, b in enumerate(batches):
        obs[i] = b.obs
        act[i] = b.act
        next_obs[i] = b.next_obs
        rewards[i] = b.rewards
        dones[i] = b.dones
    return TransitionBatch(obs, act, next_obs, rewards, dones)
def _check_for_capacity_and_batch_size(num_transitions, batch_size):
    """Check that an unshuffled ``TransitionIterator`` batches data in order.

    Observations are ``0..num_transitions-1`` and next observations are
    ``obs + 1``. The iterator must report ``ceil(num_transitions / batch_size)``
    batches. All of them are full except possibly the last, which holds the
    remainder. Together they yield every transition once, in storage order,
    with matching ``next_obs``.

    Args:
        num_transitions: number of transitions in the batch.
        batch_size: batch size passed to the iterator.
    """
    dummy = np.zeros((num_transitions, 1))
    input_obs = np.arange(num_transitions)[:, None]
    transitions = TransitionBatch(input_obs, dummy, input_obs + 1, dummy, dummy)
    it = replay_buffer.TransitionIterator(transitions, batch_size)
    # Integer ceil-division. The original referenced an undefined ``bs`` here.
    assert len(it) == (num_transitions + batch_size - 1) // batch_size
    num_full_batches = num_transitions // batch_size
    idx_check = 0
    for i, batch in enumerate(it):
        obs, _, next_obs, _, _ = batch.astuple()
        if i < num_full_batches:
            assert len(obs) == batch_size
        else:
            # Only the trailing partial batch can reach this branch.
            assert len(obs) == num_transitions % batch_size
        for j in range(len(obs)):
            assert obs[j].item() == input_obs[idx_check].item()
            assert next_obs[j].item() == obs[j].item() + 1
            idx_check += 1
    # Every transition must have been yielded exactly once.
    assert idx_check == num_transitions