def process_fn(self, batch: Batch, buffer: ReplayBuffer,
               indice: np.ndarray) -> Batch:
    """Prepare a sampled batch for learning.

    If ``self._rm_done`` is set, the done flags of ``batch`` are zeroed first.
    The batch is then handed to ``compute_nstep_return`` together with the
    target-Q callable, discount factor, n-step horizon and reward-normalisation
    flag stored on the policy.

    :param batch: the sampled transitions.
    :param buffer: the replay buffer the batch came from.
    :param indice: presumably the buffer indices of ``batch`` — confirm
        against ``compute_nstep_return``.
    :return: the Batch returned by ``compute_nstep_return``.
    """
    if self._rm_done:
        # Multiplying by the float 0. clears every flag; the product is a
        # float result rather than a boolean one.
        batch.done = batch.done * 0.
    return self.compute_nstep_return(
        batch,
        buffer,
        indice,
        self._target_q,
        self._gamma,
        self._n_step,
        self._rew_norm,
    )
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
               indice: np.ndarray) -> Batch:
    """Optionally normalise rewards and strip done flags of a sampled batch.

    With ``self._rew_norm`` set, the mean and standard deviation are measured
    on at most the first 1000 stored rewards of ``buffer``. ``batch.rew`` is
    standardised only when that deviation exceeds ``self.__eps``. With
    ``self._rm_done`` set, every done flag is then zeroed. ``indice`` is not
    used here.

    :return: the same (mutated) ``batch`` object.
    """
    if self._rew_norm:
        # Cap the statistics window at 1000 entries to avoid scanning a large buffer.
        head = buffer.rew[:min(len(buffer), 1000)]
        mu, sigma = head.mean(), head.std()
        # A (near-)constant reward sample would blow up the division; skip it.
        if sigma > self.__eps:
            batch.rew = (batch.rew - mu) / sigma
    if self._rm_done:
        batch.done = batch.done * 0.
    return batch
def add(
    self,
    batch: Batch,
    buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Store one transition in the replay buffer.

    :param Batch batch: the transition to store. Only keys that appear in
        ``self._reserved_keys`` are kept; any other key is dropped. The keys
        "obs", "act", "rew" and "done" are all required.
    :param buffer_ids: accepted so the signature matches the other buffers'
        ``add``. When it is not None, ``batch`` is treated as stacked and its
        first dimension must be exactly 1.

    :return: a tuple (current_index, episode_reward, episode_length,
        episode_start_index), each wrapped in a length-1 array. If the
        episode is not finished, episode_length and episode_reward are 0.
    """
    # Keep only the reserved keys, in a fresh Batch.
    kept = Batch()
    for key in set(self._reserved_keys).intersection(batch.keys()):
        kept.__dict__[key] = batch[key]
    batch = kept
    assert {"obs", "act", "rew", "done"}.issubset(batch.keys())

    is_stacked = buffer_ids is not None
    if is_stacked:
        assert len(batch) == 1

    # Keep only the last observation frame; stacked input carries an extra
    # leading axis, so the frame axis moves to position 1.
    if self._save_only_last_obs:
        if is_stacked:
            batch.obs = batch.obs[:, -1]
        else:
            batch.obs = batch.obs[-1]
    if not self._save_obs_next:
        batch.pop("obs_next", None)
    elif self._save_only_last_obs:
        if is_stacked:
            batch.obs_next = batch.obs_next[:, -1]
        else:
            batch.obs_next = batch.obs_next[-1]

    # Reserve a slot and obtain the episode bookkeeping for this step.
    if is_stacked:
        rew, done = batch.rew[0], batch.done[0]
    else:
        rew, done = batch.rew, batch.done
    ptr, ep_rew, ep_len, ep_idx = [
        np.array([value]) for value in self._add_index(rew, done)
    ]

    try:
        self._meta[ptr] = batch
    except ValueError:
        # Presumably storage is not allocated yet or lacks a key of this
        # batch: fix the dtypes, (re)allocate, then retry the write.
        batch.rew = batch.rew.astype(float)
        batch.done = batch.done.astype(bool)
        stack = not is_stacked
        if self._meta.is_empty():
            self._meta = _create_value(  # type: ignore
                batch, self.maxsize, stack)
        else:
            # dynamic key pops up in batch
            _alloc_by_keys_diff(self._meta, batch, self.maxsize, stack)
        self._meta[ptr] = batch
    return ptr, ep_rew, ep_len, ep_idx
def add(
    self,
    batch: Batch,
    buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Store one transition per row of ``batch`` into the ReplayBufferManager.

    Row ``i`` of ``batch`` (along its first dimension) is written to
    sub-buffer ``buffer_ids[i]``, so the batch length must equal the length
    of ``buffer_ids``. If ``buffer_ids`` is None it defaults to
    ``[0, 1, ..., buffer_num - 1]``.

    :return: a tuple (current_index, episode_reward, episode_length,
        episode_start_index) of arrays with one entry per row; the two index
        arrays include each sub-buffer's offset. If a row's episode is not
        finished, its episode_length and episode_reward are 0.
    """
    # Keep only the reserved keys, in a fresh Batch.
    kept = Batch()
    for key in set(self._reserved_keys).intersection(batch.keys()):
        kept.__dict__[key] = batch[key]
    batch = kept
    assert {"obs", "act", "rew", "done"}.issubset(batch.keys())

    # Rows live on axis 0, so the last observation frame is taken on axis 1.
    if self._save_only_last_obs:
        batch.obs = batch.obs[:, -1]
    if not self._save_obs_next:
        batch.pop("obs_next", None)
    elif self._save_only_last_obs:
        batch.obs_next = batch.obs_next[:, -1]

    if buffer_ids is None:
        buffer_ids = np.arange(self.buffer_num)

    ptr_list, rew_list, len_list, start_list = [], [], [], []
    for row, buf_id in enumerate(buffer_ids):
        sub_buffer = self.buffers[buf_id]
        local_ptr, ep_rew, ep_len, local_start = sub_buffer._add_index(
            batch.rew[row], batch.done[row])
        # Translate sub-buffer positions into positions of the shared storage.
        offset = self._offset[buf_id]
        global_ptr = local_ptr + offset
        ptr_list.append(global_ptr)
        len_list.append(ep_len)
        rew_list.append(ep_rew)
        start_list.append(local_start + offset)
        self.last_index[buf_id] = global_ptr
        self._lengths[buf_id] = len(sub_buffer)

    ptrs = np.array(ptr_list)
    try:
        self._meta[ptrs] = batch
    except ValueError:
        # Presumably storage is not allocated yet or lacks a key of this
        # batch: fix the dtypes, (re)allocate, re-share it with the
        # sub-buffers, then retry the write.
        batch.rew = batch.rew.astype(float)
        batch.done = batch.done.astype(bool)
        if self._meta.is_empty():
            self._meta = _create_value(  # type: ignore
                batch, self.maxsize, stack=False)
        else:
            # dynamic key pops up in batch
            _alloc_by_keys_diff(self._meta, batch, self.maxsize, False)
        self._set_batch_for_children()
        self._meta[ptrs] = batch
    return ptrs, np.array(rew_list), np.array(len_list), np.array(start_list)
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
               indice: np.ndarray) -> Batch:
    """Optionally strip done flags and normalise rewards of a sampled batch.

    With ``self._rm_done`` set, every done flag is zeroed. With
    ``self._rew_norm`` set, rewards are transformed by ``self.norm_func``
    when one is provided. Otherwise they are standardised with the mean and
    standard deviation of at most the first 1000 stored rewards; a standard
    deviation close to 0 falls back to mean 0 and std 1. ``indice`` is not
    used here.

    :return: the same (mutated) ``batch`` object.
    """
    if self._rm_done:
        batch.done = batch.done * 0.
    if not self._rew_norm:
        return batch
    if self.norm_func is not None:
        batch.rew = self.norm_func(batch.rew)
        return batch
    # Cap the statistics window at 1000 entries to avoid scanning a large buffer.
    head = buffer.rew[:min(len(buffer), 1000)]
    mean, std = head.mean(), head.std()
    if np.isclose(std, 0):
        # Degenerate spread: keep the rewards' values (but still divide).
        mean, std = 0, 1
    batch.rew = (batch.rew - mean) / std
    return batch
def test_replaybuffer(size=10, bufsize=20):
    """Exercise ``ReplayBuffer`` end to end.

    Covers sequential ``add`` with length tracking, sampling, dtype/shape of
    stored fields, storage created from the first added batch, new keys
    appearing in later batches, batch-style (stacked) ``add`` via
    ``buffer_ids``, out-of-range ``__getitem__``, and ``prev``/``next``
    navigation.
    """
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize)
    # update an empty buffer with itself; its repr must stay the bare class name
    buf.update(buf)
    assert str(buf) == buf.__class__.__name__ + '()'
    obs = env.reset()
    # 25 steps in total, i.e. more than the default bufsize of 20
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        buf.add(
            Batch(obs=obs, act=[a], rew=rew, done=done, obs_next=obs_next,
                  info=info))
        obs = obs_next
        # length grows by one per add and saturates at bufsize
        assert len(buf) == min(bufsize, i + 1)
    # act=[a] is stored as an int column of shape (bufsize, 1)
    assert buf.act.dtype == int
    assert buf.act.shape == (bufsize, 1)
    # asking for more samples than stored must still yield valid indices
    data, indices = buf.sample(bufsize * 2)
    assert (indices < len(buf)).all()
    assert (data.obs < size).all()
    assert (0 <= data.done).all() and (data.done <= 1).all()
    b = ReplayBuffer(size=10)
    # neg bsz should return empty index
    assert b.sample_indices(-1).tolist() == []
    # first add into an empty buffer: storage is built from this batch
    ptr, ep_rew, ep_len, ep_idx = b.add(
        Batch(obs=1, act=1, rew=1, done=1, obs_next='str',
              info={'a': 3, 'b': {'c': 5.0}}))
    assert b.obs[0] == 1
    assert b.done[0]
    assert b.obs_next[0] == 'str'
    # untouched slots read as 0 (None for the string-valued obs_next)
    assert np.all(b.obs[1:] == 0)
    assert np.all(b.obs_next[1:] == np.array(None))
    assert b.info.a[0] == 3 and b.info.a.dtype == int
    assert np.all(b.info.a[1:] == 0)
    assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == float
    assert np.all(b.info.b.c[1:] == 0.0)
    # done=1 on the first step: a finished one-step episode at slot 0
    assert ptr.shape == (1, ) and ptr[0] == 0
    assert ep_rew.shape == (1, ) and ep_rew[0] == 1
    assert ep_len.shape == (1, ) and ep_len[0] == 1
    assert ep_idx.shape == (1, ) and ep_idx[0] == 0
    # test extra keys pop up, the buffer should handle it dynamically
    batch = Batch(obs=2, act=2, rew=2, done=0, obs_next="str2",
                  info={"a": 4, "d": {"e": -np.inf}})
    b.add(batch)
    info_keys = ["a", "b", "d"]
    assert set(b.info.keys()) == set(info_keys)
    # "b" is missing from this batch, so its slot-1 value reads as 0
    assert b.info.a[1] == 4 and b.info.b.c[1] == 0
    assert b.info.d.e[1] == -np.inf
    # test batch-style adding method, where len(batch) == 1
    batch.done = [1]
    batch.info.e = np.zeros([1, 4])
    batch = Batch.stack([batch])
    ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0])
    # the episode spans slots 1-2: reward 2 + 2, length 2, starting at slot 1
    assert ptr.shape == (1, ) and ptr[0] == 2
    assert ep_rew.shape == (1, ) and ep_rew[0] == 4
    assert ep_len.shape == (1, ) and ep_len[0] == 2
    assert ep_idx.shape == (1, ) and ep_idx[0] == 1
    assert set(b.info.keys()) == set(info_keys + ["e"])
    assert b.info.e.shape == (b.maxsize, 1, 4)
    # index beyond the buffer size (10) must raise
    with pytest.raises(IndexError):
        b[22]
    # test prev / next
    # expected values stop at episode boundaries (slot 0 | slots 1-2)
    assert np.all(b.prev(np.array([0, 1, 2])) == [0, 1, 1])
    assert np.all(b.next(np.array([0, 1, 2])) == [0, 2, 2])
    # append an unfinished step at slot 3
    batch.done = [0]
    b.add(batch, buffer_ids=[0])
    assert np.all(b.prev(np.array([0, 1, 2, 3])) == [0, 1, 1, 3])
    assert np.all(b.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3])