def _batch_set_item(
    source: Batch, indices: np.ndarray, target: Batch, size: int
) -> None:
    """Assign ``target``'s values into ``source`` at positions ``indices``,
    lazily allocating storage (of length ``size``) for keys that ``source``
    does not hold yet.

    For any key chain k, there are four cases:

    1. source[k] is non-reserved, but target[k] does not exist or is reserved
    2. source[k] does not exist or is reserved, but target[k] is non-reserved
    3. both source[k] and target[k] are non-reserved
    4. both source[k] and target[k] do not exist or are reserved, do nothing.

    A special case in case 4: if target[k] is reserved but source[k] does
    not exist, make source[k] reserved, too.
    """
    for k, vt in target.items():
        if not isinstance(vt, Batch) or not vt.is_empty():
            # target[k] is non-reserved (cases 2 and 3)
            vs = source.get(k, Batch())
            if isinstance(vs, Batch):
                if vs.is_empty():
                    # case 2, use __dict__ to avoid many type checks
                    # (allocate storage shaped after one element of vt)
                    source.__dict__[k] = _create_value(vt[0], size)
                else:
                    # source[k] is a non-empty Batch, so target[k] must be a
                    # Batch too; recurse to allocate any missing nested keys
                    assert isinstance(vt, Batch)
                    _batch_set_item(source.__dict__[k], indices, vt, size)
        else:
            # target[k] is reserved
            # case 1 or special case of case 4
            if k not in source.__dict__:
                source.__dict__[k] = Batch()
            # reserved target key: there is no data to write at `indices`
            continue
        source.__dict__[k][indices] = vt
def _add_to_buffer(self, name: str, inst: Any) -> None:
    """Write ``inst`` into the buffer slot ``name`` at the current write
    index, allocating storage the first time a key (or sub-key) is seen.
    """
    meta_dict = self._meta.__dict__
    # First sighting of this key: allocate a maxsize-long column for it.
    if name not in meta_dict:
        meta_dict[name] = _create_value(inst, self._maxsize)
    slot = meta_dict[name]
    # Reject data whose per-step shape disagrees with the allocated column.
    if isinstance(inst, np.ndarray) and slot.shape[1:] != inst.shape:
        raise ValueError(
            "Cannot add data to a buffer with different shape, with key "
            f"{name}, expect {slot.shape[1:]}, given {inst.shape}.")
    try:
        slot[self._index] = inst
    except KeyError:
        # ``inst`` carries sub-keys the stored Batch has not seen yet:
        # allocate each missing sub-key, then retry the assignment.
        missing = set(inst.keys()).difference(slot.__dict__.keys())
        for sub_key in missing:
            slot.__dict__[sub_key] = _create_value(inst[sub_key], self._maxsize)
        slot[self._index] = inst
def add( self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Add a batch of data into replay buffer. :param Batch batch: the input data batch. Its keys must belong to the 7 reserved keys, and "obs", "act", "rew", "done" is required. :param buffer_ids: to make consistent with other buffer's add function; if it is not None, we assume the input batch's first dimension is always 1. Return (current_index, episode_reward, episode_length, episode_start_index). If the episode is not finished, the return value of episode_length and episode_reward is 0. """ # preprocess batch b = Batch() for key in set(self._reserved_keys).intersection(batch.keys()): b.__dict__[key] = batch[key] batch = b assert set(["obs", "act", "rew", "done"]).issubset(batch.keys()) stacked_batch = buffer_ids is not None if stacked_batch: assert len(batch) == 1 if self._save_only_last_obs: batch.obs = batch.obs[:, -1] if stacked_batch else batch.obs[-1] if not self._save_obs_next: batch.pop("obs_next", None) elif self._save_only_last_obs: batch.obs_next = ( batch.obs_next[:, -1] if stacked_batch else batch.obs_next[-1] ) # get ptr if stacked_batch: rew, done = batch.rew[0], batch.done[0] else: rew, done = batch.rew, batch.done ptr, ep_rew, ep_len, ep_idx = list( map(lambda x: np.array([x]), self._add_index(rew, done)) ) try: self._meta[ptr] = batch except ValueError: stack = not stacked_batch batch.rew = batch.rew.astype(float) batch.done = batch.done.astype(bool) if self._meta.is_empty(): self._meta = _create_value( # type: ignore batch, self.maxsize, stack) else: # dynamic key pops up in batch _alloc_by_keys_diff(self._meta, batch, self.maxsize, stack) self._meta[ptr] = batch return ptr, ep_rew, ep_len, ep_idx
def add( self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Add a batch of data into ReplayBufferManager. Each of the data's length (first dimension) must equal to the length of buffer_ids. By default buffer_ids is [0, 1, ..., buffer_num - 1]. Return (current_index, episode_reward, episode_length, episode_start_index). If the episode is not finished, the return value of episode_length and episode_reward is 0. """ # preprocess batch new_batch = Batch() for key in set(self._reserved_keys).intersection(batch.keys()): new_batch.__dict__[key] = batch[key] batch = new_batch assert set(["obs", "act", "rew", "done"]).issubset(batch.keys()) if self._save_only_last_obs: batch.obs = batch.obs[:, -1] if not self._save_obs_next: batch.pop("obs_next", None) elif self._save_only_last_obs: batch.obs_next = batch.obs_next[:, -1] # get index if buffer_ids is None: buffer_ids = np.arange(self.buffer_num) ptrs, ep_lens, ep_rews, ep_idxs = [], [], [], [] for batch_idx, buffer_id in enumerate(buffer_ids): ptr, ep_rew, ep_len, ep_idx = self.buffers[buffer_id]._add_index( batch.rew[batch_idx], batch.done[batch_idx] ) ptrs.append(ptr + self._offset[buffer_id]) ep_lens.append(ep_len) ep_rews.append(ep_rew) ep_idxs.append(ep_idx + self._offset[buffer_id]) self.last_index[buffer_id] = ptr + self._offset[buffer_id] self._lengths[buffer_id] = len(self.buffers[buffer_id]) ptrs = np.array(ptrs) try: self._meta[ptrs] = batch except ValueError: batch.rew = batch.rew.astype(float) batch.done = batch.done.astype(bool) if self._meta.is_empty(): self._meta = _create_value( # type: ignore batch, self.maxsize, stack=False) else: # dynamic key pops up in batch _alloc_by_keys_diff(self._meta, batch, self.maxsize, False) self._set_batch_for_children() self._meta[ptrs] = batch return ptrs, np.array(ep_rews), np.array(ep_lens), np.array(ep_idxs)
def update(self, buffer: "ReplayBuffer") -> np.ndarray:
    """Move the data from the given buffer to current buffer.

    Return the updated indices. If update fails, return an empty array.
    """
    # Nothing to transfer if either side has no capacity / no data.
    if len(buffer) == 0 or self.maxsize == 0:
        return np.array([], int)
    # Temporarily disable frame stacking on the source so that every stored
    # transition index becomes visible to sample_indices.
    saved_stack_num = buffer.stack_num
    buffer.stack_num = 1
    from_indices = buffer.sample_indices(0)  # get all available indices
    buffer.stack_num = saved_stack_num
    if len(from_indices) == 0:
        return np.array([], int)
    # Advance the circular write pointer once per incoming transition,
    # recording each destination slot along the way.
    to_indices = []
    for _ in from_indices:
        to_indices.append(self._index)
        self.last_index[0] = self._index
        self._index = (self._index + 1) % self.maxsize
        self._size = min(self._size + 1, self.maxsize)
    to_indices = np.array(to_indices)
    # Lazily allocate our storage from the source's layout on first use.
    if self._meta.is_empty():
        self._meta = _create_value(  # type: ignore
            buffer._meta, self.maxsize, stack=False)
    self._meta[to_indices] = buffer._meta[from_indices]
    return to_indices
def _buffer_allocator(self, key: List[str], value: Any) -> None:
    """Allocate memory on buffer._meta for new (key, value) pair."""
    # Descend through the nested Batch along the key chain, then create
    # the leaf storage sized to the buffer capacity.
    parents, leaf = key[:-1], key[-1]
    node = self._meta
    for part in parents:
        node = node[part]
    node[leaf] = _create_value(value, self.maxsize)