def get(self, indice: Union[slice, np.ndarray], key: str,
        stack_num: Optional[int] = None) -> 'Union[Batch, np.ndarray]':
    """Return the stacked result for ``key`` at the given indices.

    The result looks like [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is
    ``self.key`` and t is ``indice``. The number of stacked frames (4 in this
    example) defaults to the value stored in ``self._stack``.

    Stacking never crosses an episode boundary: when the previous frame is
    flagged ``done``, the current frame is repeated instead.

    :param indice: a slice (resolved against the ``self._size`` valid
        entries with standard Python slice semantics), an integer, or an
        array of integer indices. The argument is never modified.
    :param str key: name of the stored field. ``'obs_next'`` is derived from
        ``'obs'`` when ``self._save_s_`` is false.
    :param int stack_num: number of frames to stack; ``0`` disables stacking.
        ``None`` means ``self._stack``.

    :return: without stacking, ``storage[indice]`` (a plain ``dict`` of
        arrays for a multi-key field); with stacking, an array (or a Batch
        for a multi-key field) with the frames stacked on axis
        ``indice.ndim``.
    """
    if stack_num is None:
        stack_num = self._stack
    if isinstance(indice, slice):
        # slice.indices handles None, negative and out-of-range bounds
        indice = np.arange(*indice.indices(self._size))
    else:
        # work on a private copy: the index is shifted in place below
        indice = np.array(indice, copy=True)

    # Temporarily mark the most recently written frame as done so that
    # stacking / next-frame lookup never runs past the write position.
    last_index = (self._index - 1 + self._size) % self._size
    last_done, self.done[last_index] = self.done[last_index], True
    many_keys = None
    frames = None
    try:
        if key == 'obs_next' and not self._save_s_:
            # obs_next is not stored: use the following obs, unless the
            # transition ended the episode (then stay on the same frame)
            indice += 1 - self.done[indice].astype(int)
            indice[indice == self._size] = 0
            key = 'obs'
        many_keys = self._meta[key] if key in self._meta else None
        if stack_num != 0:
            if many_keys is not None:
                frames = {k: [] for k in many_keys}
            else:
                frames = []
            for _ in range(stack_num):
                # walk backwards in time, prepending so the oldest is first
                if many_keys is not None:
                    for k in many_keys:
                        storage = self.__dict__['_' + key + '@' + k]
                        frames[k].insert(0, storage[indice])
                else:
                    frames.insert(0, self.__dict__[key][indice])
                # np.asarray keeps scalar (0-d) indices item-assignable
                pre_indice = np.asarray(indice - 1)
                pre_indice[pre_indice == -1] = self._size - 1
                # step back only if the previous frame did not end an episode
                indice = np.asarray(
                    pre_indice + self.done[pre_indice].astype(int))
                indice[indice == self._size] = 0
    finally:
        # always undo the temporary flag, even if indexing failed
        self.done[last_index] = last_done

    if stack_num == 0:
        if many_keys is not None:
            return {
                k: self.__dict__['_' + key + '@' + k][indice]
                for k in many_keys
            }
        return self.__dict__[key][indice]
    axis = indice.ndim
    if many_keys is not None:
        return Batch(**{
            k: np.stack(v, axis=axis) for k, v in frames.items()
        })
    return np.stack(frames, axis=axis)
class ReplayBuffer:
    """:class:`~tianshou.data.ReplayBuffer` stores data generated from
    interaction between the policy and environment. The current implementation
    of Tianshou typically uses 7 reserved keys in
    :class:`~tianshou.data.Batch`:

    * ``obs`` the observation of step :math:`t` ;
    * ``act`` the action of step :math:`t` ;
    * ``rew`` the reward of step :math:`t` ;
    * ``done`` the done flag of step :math:`t` ;
    * ``obs_next`` the observation of step :math:`t+1` ;
    * ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` \
        function returns 4 arguments, and the last one is ``info``);
    * ``policy`` the data computed by policy in step :math:`t`;

    In addition, this buffer always stores ``done_bk`` (defaults to the value
    of ``done``), and optionally ``mask``, ``erew``, ``goal``, ``irew``,
    ``log_prob``, ``alpha`` and ``int_rew`` (see the keyword arguments below
    and :meth:`add`).

    The following code snippet illustrates its usage:
    ::

        >>> import numpy as np
        >>> from tianshou.data import ReplayBuffer
        >>> buf = ReplayBuffer(size=20)
        >>> for i in range(3):
        ...     buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
        >>> buf.obs
        # since we set size = 20, len(buf.obs) == 20.
        array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
               0., 0., 0., 0.])
        >>> # but there are only three valid items, so len(buf) == 3.
        >>> len(buf)
        3
        >>> buf2 = ReplayBuffer(size=10)
        >>> for i in range(15):
        ...     buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
        >>> len(buf2)
        10
        >>> buf2.obs
        # since its size = 10, it only stores the last 10 steps' result.
        array([10., 11., 12., 13., 14.,  5.,  6.,  7.,  8.,  9.])

        >>> # move buf2's result into buf (meanwhile keep it chronologically)
        >>> buf.update(buf2)
        >>> buf.obs
        array([ 0.,  1.,  2.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14.,
                0.,  0.,  0.,  0.,  0.,  0.,  0.])

        >>> # get a random sample from buffer
        >>> # the batch_data is equal to buf[indice].
        >>> batch_data, indice = buf.sample(batch_size=4)
        >>> batch_data.obs == buf[indice].obs
        array([ True,  True,  True,  True])

    :class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling
    (typically for RNN usage, see issue#19), ignoring storing the next
    observation (save memory in atari tasks), and multi-modal observation (see
    issue#38):
    ::

        >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
        >>> for i in range(16):
        ...     done = i % 5 == 0
        ...     buf.add(obs={'id': i}, act=i, rew=i, done=done,
        ...             obs_next={'id': i + 1})
        >>> print(buf)  # you can see obs_next is not saved in buf
        ReplayBuffer(
            act: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]),
            done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]),
            info: Batch(),
            obs: Batch(
                     id: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]),
                 ),
            policy: Batch(),
            rew: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]),
        )
        >>> index = np.arange(len(buf))
        >>> print(buf.get(index, 'obs').id)
        [[ 7.  7.  8.  9.]
         [ 7.  8.  9. 10.]
         [11. 11. 11. 11.]
         [11. 11. 11. 12.]
         [11. 11. 12. 13.]
         [11. 12. 13. 14.]
         [12. 13. 14. 15.]
         [ 7.  7.  7.  7.]
         [ 7.  7.  7.  8.]]
        >>> # here is another way to get the stacked data
        >>> # (stack only for obs and obs_next)
        >>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
        0.0
        >>> # we can get obs_next through __getitem__, even if it doesn't exist
        >>> print(buf[:].obs_next.id)
        [[ 7.  8.  9. 10.]
         [ 7.  8.  9. 10.]
         [11. 11. 11. 12.]
         [11. 11. 12. 13.]
         [11. 12. 13. 14.]
         [12. 13. 14. 15.]
         [12. 13. 14. 15.]
         [ 7.  7.  7.  8.]
         [ 7.  7.  8.  9.]]

    :param int size: the size of replay buffer.
    :param int stack_num: the frame-stack sampling argument, should be greater
        than 1, defaults to 0 (no stacking).
    :param bool ignore_obs_next: whether to store obs_next, defaults to
        ``False``.
    :param bool sample_avail: the parameter indicating sampling only available
        index when using frame-stack sampling method, defaults to ``False``.
        This feature is not supported in Prioritized Replay Buffer currently.

    Recognised keyword arguments:

    * ``max_ep_len`` (default ``None``): used to identify whether the end is
      terminated by time limit; see :meth:`update`.
    * ``ens_num`` (default ``None``): if not ``None``, a random 0/1 ``mask``
      of this length is stored with every transition.
    * ``ngu`` (default ``False``): if true, ``erew`` and ``goal`` are stored
      (initialised to 0) with every transition.
    * ``rand2`` (default ``False``): if true, ``irew`` is stored (initialised
      to 0) with every transition.
    * ``non_episodic`` (default ``False``): see :meth:`update`.
    """

    def __init__(self, size: int, stack_num: Optional[int] = 0,
                 ignore_obs_next: bool = False,
                 sample_avail: bool = False, **kwargs) -> None:
        super().__init__()
        self._maxsize = size
        self._stack = stack_num
        assert stack_num != 1, 'stack_num should greater than 1'
        self._avail = sample_avail and stack_num > 1
        self._avail_index = []
        self._save_s_ = not ignore_obs_next
        self._index = 0
        self._size = 0
        # every stored field lives in this Batch; see __getattr__
        self._meta = Batch()
        # used to identify whether the end is terminated by time limit.
        self._max_ep_len = kwargs.get('max_ep_len', None)
        # if not none, add mask
        self._ens_num = kwargs.get('ens_num', None)
        self._ngu = kwargs.get('ngu', False)
        self._rand2 = kwargs.get('rand2', False)
        self.non_episodic = kwargs.get('non_episodic', False)
        self.reset()

    def __len__(self) -> int:
        """Return len(self)."""
        return self._size

    def __repr__(self) -> str:
        """Return str(self)."""
        # drop the leading "Batch" of the storage's repr
        return self.__class__.__name__ + repr(self._meta)[5:]

    def __getattr__(self, key: str) -> Union['Batch', Any]:
        """Return self.key, looked up in the internal storage.

        Returns ``None`` when no such field has been stored yet (other
        methods rely on this, e.g. ``self.log_prob is not None``).

        :raises AttributeError: if the storage itself does not exist yet
            (e.g. on an instance created without ``__init__``); looking it up
            through ``self._meta`` here would recurse forever.
        """
        meta = self.__dict__.get('_meta')
        if meta is None:
            raise AttributeError(key)
        return meta.__dict__.get(key)

    def _add_to_buffer(self, name: str, inst: Any) -> None:
        """Write ``inst`` into field ``name`` at the current write index.

        The storage for a field (or for a new sub-key of an existing field)
        is created on first use via ``_create_value``.

        :raises ValueError: if ``inst`` is an array whose shape differs from
            the per-item shape of the existing storage.
        """
        try:
            value = self._meta.__dict__[name]
        except KeyError:
            self._meta.__dict__[name] = _create_value(inst, self._maxsize)
            value = self._meta.__dict__[name]
        if isinstance(inst, np.ndarray) and value.shape[1:] != inst.shape:
            raise ValueError(
                "Cannot add data to a buffer with different shape, key: "
                f"{name}, expect shape: {value.shape[1:]}, "
                f"given shape: {inst.shape}.")
        try:
            value[self._index] = inst
        except KeyError:
            # inst carries keys the storage has not seen yet: allocate them
            for key in set(inst.keys()).difference(value.__dict__.keys()):
                value.__dict__[key] = _create_value(inst[key], self._maxsize)
            value[self._index] = inst

    def _get_stack_num(self):
        """Return the current frame-stack number."""
        return self._stack

    def _set_stack_num(self, num):
        """Set the frame-stack number."""
        self._stack = num

    def _slice_to_indices(self, index: slice) -> np.ndarray:
        """Resolve a slice against the valid entries (standard semantics)."""
        return np.arange(*index.indices(self._size))

    def update(self, buffer: 'ReplayBuffer') -> None:
        """Move the data from the given buffer to self.

        Entries are copied in chronological order, starting from the oldest
        entry of ``buffer``. Note that ``buffer`` itself is modified while
        copying (``done``, ``done_bk`` and ``obs_next`` entries, see below).
        """
        if len(buffer) == 0:
            return
        i = begin = buffer._index % len(buffer)
        # read single frames from the source, whatever its stacking setting
        origin = buffer._get_stack_num()
        buffer._set_stack_num(0)
        try:
            _len = 0
            # NOTE(review): the conditionals below are applied as sequential,
            # independent statements - confirm this matches the intended
            # nesting.
            while True:
                _len += 1
                is_last = (i + 1) % len(buffer) == begin
                if self._max_ep_len is not None \
                        and _len >= self._max_ep_len:
                    # reached the step limit: clear the done flag
                    buffer.done[i] *= 0
                if is_last:
                    buffer.done_bk[i] = 1
                if self.non_episodic and is_last and buffer.done[i] != 0:
                    buffer.obs_next[i] = buffer.obs[begin]
                if not self._save_s_:
                    buffer.obs_next[i] = None
                self.add(**buffer[i])
                i = (i + 1) % len(buffer)
                if i == begin:
                    break
        finally:
            # restore the source buffer's setting even if add() failed
            buffer._set_stack_num(origin)

    def add(self,
            obs: 'Union[dict, Batch, np.ndarray]',
            act: Union[np.ndarray, float],
            rew: Union[int, float],
            done: bool,
            obs_next: 'Optional[Union[dict, Batch, np.ndarray]]' = None,
            info: dict = {},
            policy: 'Optional[Union[dict, Batch]]' = {},
            **kwargs) -> None:
        """Add a batch of data into replay buffer.

        Besides the named arguments, the keyword arguments ``done_bk``
        (defaults to ``done``), ``log_prob``, ``alpha`` and ``int_rew`` are
        stored when given. ``obs_next`` is ignored if the buffer was created
        with ``ignore_obs_next=True``.
        """
        # NOTE: the ``{}`` defaults are only read here, never mutated.
        assert isinstance(info, (dict, Batch)), \
            'You should return a dict in the last argument of env.step().'
        self._add_to_buffer('obs', obs)
        self._add_to_buffer('act', act)
        self._add_to_buffer('rew', rew)
        self._add_to_buffer('done', done)
        self._add_to_buffer('done_bk', kwargs.get('done_bk', done))
        if self._save_s_:
            if obs_next is None:
                obs_next = Batch()
            self._add_to_buffer('obs_next', obs_next)
        self._add_to_buffer('info', info)
        self._add_to_buffer('policy', policy)
        if self._ens_num is not None:
            mask = np.random.randint(2, size=self._ens_num)
            self._add_to_buffer('mask', mask)
        if self._ngu:
            self._add_to_buffer('erew', 0)
            self._add_to_buffer('goal', 0)
        if self._rand2:
            self._add_to_buffer('irew', 0)
        for extra in ('log_prob', 'alpha', 'int_rew'):
            if extra in kwargs:
                self._add_to_buffer(extra, kwargs[extra])

        # maintain available index for frame-stack sampling
        if self._avail:
            # update current frame
            avail = sum(self.done[i] for i in range(
                self._index - self._stack + 1, self._index)) == 0
            if self._size < self._stack - 1:
                avail = False
            if avail and self._index not in self._avail_index:
                self._avail_index.append(self._index)
            elif not avail and self._index in self._avail_index:
                self._avail_index.remove(self._index)
            # remove the later available frame because of broken storage
            t = (self._index + self._stack - 1) % self._maxsize
            if t in self._avail_index:
                self._avail_index.remove(t)

        if self._maxsize > 0:
            self._size = min(self._size + 1, self._maxsize)
            self._index = (self._index + 1) % self._maxsize
        else:
            self._size = self._index = self._index + 1

    def reset(self) -> None:
        """Clear all the data in replay buffer."""
        self._index = 0
        self._size = 0
        self._avail_index = []

    def sample(self, batch_size: int) -> 'Tuple[Batch, np.ndarray]':
        """Get a random sample from buffer with size equal to batch_size.

        Return all the data in the buffer if batch_size is ``0``.

        :return: Sample data and its corresponding index inside the buffer.
        """
        if batch_size > 0:
            _all = self._avail_index if self._avail else self._size
            indice = np.random.choice(_all, batch_size)
        elif self._avail:
            indice = np.array(self._avail_index)
        else:
            indice = np.concatenate([
                np.arange(self._index, self._size),
                np.arange(0, self._index),
            ])
        assert len(indice) > 0, 'No available indice can be sampled.'
        return self[indice], indice

    def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str,
            stack_num: Optional[int] = None) -> 'Union[Batch, np.ndarray]':
        """Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
        where s is self.key, t is indice. The stack_num (here equals to 4) is
        given from buffer initialization procedure.

        Stacking never crosses an episode boundary: when the previous frame
        is flagged ``done``, the current frame is repeated instead.

        :param indice: a slice (resolved against the valid entries with
            standard Python slice semantics), an integer, or an array of
            integer indices. The argument is never modified.
        :param str key: name of the stored field. ``'obs_next'`` is derived
            from ``'obs'`` when it is not stored.
        :param int stack_num: number of frames to stack; values ``<= 0``
            disable stacking. ``None`` means the buffer's own setting.

        :return: ``storage[indice]`` without stacking, otherwise the frames
            stacked on axis ``indice.ndim``. An empty Batch is returned when
            the field is an empty Batch that cannot be indexed.
        """
        if stack_num is None:
            stack_num = self._stack
        if isinstance(indice, slice):
            indice = self._slice_to_indices(indice)
        else:
            # work on a private copy: the index is shifted in place below
            indice = np.array(indice, copy=True)
        # set last frame done to True (temporarily), so that stacking and the
        # next-frame lookup never run past the write position
        last_index = (self._index - 1 + self._size) % self._size
        last_done, self.done[last_index] = self.done[last_index], True
        try:
            if key == 'obs_next' and \
                    (not self._save_s_ or self.obs_next is None):
                # use the following obs, unless the transition ended the
                # episode (then stay on the same frame)
                indice += 1 - self.done[indice].astype(int)
                indice[indice == self._size] = 0
                key = 'obs'
            val = self._meta.__dict__[key]
            try:
                if stack_num > 0:
                    stack = []
                    for _ in range(stack_num):
                        # walk backwards in time, oldest frame first
                        stack.insert(0, val[indice])
                        pre_indice = np.asarray(indice - 1)
                        pre_indice[pre_indice == -1] = self._size - 1
                        # step back only if the previous frame is not done
                        indice = np.asarray(
                            pre_indice + self.done[pre_indice].astype(int))
                        indice[indice == self._size] = 0
                    if isinstance(val, Batch):
                        stack = Batch.stack(stack, axis=indice.ndim)
                    else:
                        stack = np.stack(stack, axis=indice.ndim)
                else:
                    stack = val[indice]
            except IndexError:
                # only an empty Batch is allowed to fail indexing
                if not isinstance(val, Batch) or len(val.__dict__) > 0:
                    raise
                stack = Batch()
        finally:
            # always undo the temporary flag, even if something failed
            self.done[last_index] = last_done
        return stack

    def __getitem__(self, index: Union[
            slice, int, np.integer, np.ndarray]) -> 'Batch':
        """Return a data batch: self[index]. If stack_num is set to be > 0,
        return the stacked obs and obs_next with shape [batch, len, ...].

        A slice is resolved against the valid entries of the buffer, so every
        field of the returned batch covers the same transitions.
        """
        if isinstance(index, slice):
            index = self._slice_to_indices(index)
        ret = Batch(
            obs=self.get(index, 'obs'),
            act=self.act[index],
            rew=self.rew[index],
            done=self.done[index],
            done_bk=self.done_bk[index],
            obs_next=self.get(index, 'obs_next'),
            info=self.get(index, 'info'),
            policy=self.get(index, 'policy'),
        )
        # optional fields are present only once they have been stored
        if self.log_prob is not None:
            ret.log_prob = self.log_prob[index]
        if self.alpha is not None:
            ret.alpha = self.alpha[index]
        # NOTE: 'int_rew' may be stored by add() but is not returned here
        # (its lookup was disabled in the original code).
        if self._ens_num is not None:
            ret.mask = self.mask[index]
        if self._ngu:
            ret.erew = self.erew[index]
            ret.goal = self.goal[index]
        if self._rand2:
            ret.irew = self.irew[index]
        return ret