def add(
    self,
    batch: Batch,
    buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Add a batch of data into replay buffer.

    :param Batch batch: the input data batch. Its keys must belong to the 7
        reserved keys, and "obs", "act", "rew", "done" are required.
    :param buffer_ids: kept for consistency with other buffers' add functions;
        if it is not None, we assume the input batch's first dimension is
        always 1.

    Return (current_index, episode_reward, episode_length,
    episode_start_index). If the episode is not finished, the return value of
    episode_length and episode_reward is 0.
    """
    # preprocess batch: keep only the reserved keys
    b = Batch()
    for key in set(self._reserved_keys).intersection(batch.keys()):
        b.__dict__[key] = batch[key]
    batch = b
    assert set(["obs", "act", "rew", "done"]).issubset(batch.keys())
    stacked_batch = buffer_ids is not None
    if stacked_batch:
        assert len(batch) == 1
    if self._save_only_last_obs:
        batch.obs = batch.obs[:, -1] if stacked_batch else batch.obs[-1]
    if not self._save_obs_next:
        batch.pop("obs_next", None)
    elif self._save_only_last_obs:
        batch.obs_next = (
            batch.obs_next[:, -1] if stacked_batch else batch.obs_next[-1]
        )
    # get ptr
    if stacked_batch:
        rew, done = batch.rew[0], batch.done[0]
    else:
        rew, done = batch.rew, batch.done
    ptr, ep_rew, ep_len, ep_idx = list(
        map(lambda x: np.array([x]), self._add_index(rew, done))
    )
    try:
        self._meta[ptr] = batch
    except ValueError:
        # internal storage does not match the batch yet:
        # (re)allocate it, then retry the assignment
        stack = not stacked_batch
        batch.rew = batch.rew.astype(float)
        batch.done = batch.done.astype(bool)
        if self._meta.is_empty():
            self._meta = _create_value(  # type: ignore
                batch, self.maxsize, stack)
        else:  # dynamic key pops up in batch
            _alloc_by_keys_diff(self._meta, batch, self.maxsize, stack)
        self._meta[ptr] = batch
    return ptr, ep_rew, ep_len, ep_idx
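# Usage sketch for the method above. This is illustrative, not from this
# file: it assumes the standard tianshou.data.ReplayBuffer / Batch imports
# and made-up transition values.
#
#     from tianshou.data import Batch, ReplayBuffer
#
#     buf = ReplayBuffer(size=10)
#     for step in range(3):
#         ptr, ep_rew, ep_len, ep_idx = buf.add(
#             Batch(obs=step, act=0, rew=1.0, done=(step == 2),
#                   obs_next=step + 1, info={})
#         )
#     # ep_rew and ep_len stay 0 until done is True; on the final add they
#     # hold the episode's total reward (3.0) and length (3).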
def add(
    self,
    batch: Batch,
    buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Add a batch of data into ReplayBufferManager.

    Each of the data's length (first dimension) must be equal to the length
    of buffer_ids. By default buffer_ids is [0, 1, ..., buffer_num - 1].

    Return (current_index, episode_reward, episode_length,
    episode_start_index). If the episode is not finished, the return value of
    episode_length and episode_reward is 0.
    """
    # preprocess batch: keep only the reserved keys
    new_batch = Batch()
    for key in set(self._reserved_keys).intersection(batch.keys()):
        new_batch.__dict__[key] = batch[key]
    batch = new_batch
    assert set(["obs", "act", "rew", "done"]).issubset(batch.keys())
    if self._save_only_last_obs:
        batch.obs = batch.obs[:, -1]
    if not self._save_obs_next:
        batch.pop("obs_next", None)
    elif self._save_only_last_obs:
        batch.obs_next = batch.obs_next[:, -1]
    # get index
    if buffer_ids is None:
        buffer_ids = np.arange(self.buffer_num)
    ptrs, ep_lens, ep_rews, ep_idxs = [], [], [], []
    for batch_idx, buffer_id in enumerate(buffer_ids):
        ptr, ep_rew, ep_len, ep_idx = self.buffers[buffer_id]._add_index(
            batch.rew[batch_idx], batch.done[batch_idx]
        )
        # convert per-buffer indices to global ones via each buffer's offset
        ptrs.append(ptr + self._offset[buffer_id])
        ep_lens.append(ep_len)
        ep_rews.append(ep_rew)
        ep_idxs.append(ep_idx + self._offset[buffer_id])
        self.last_index[buffer_id] = ptr + self._offset[buffer_id]
        self._lengths[buffer_id] = len(self.buffers[buffer_id])
    ptrs = np.array(ptrs)
    try:
        self._meta[ptrs] = batch
    except ValueError:
        # shared storage does not match the batch yet:
        # (re)allocate it, then retry the assignment
        batch.rew = batch.rew.astype(float)
        batch.done = batch.done.astype(bool)
        if self._meta.is_empty():
            self._meta = _create_value(  # type: ignore
                batch, self.maxsize, stack=False)
        else:  # dynamic key pops up in batch
            _alloc_by_keys_diff(self._meta, batch, self.maxsize, False)
        self._set_batch_for_children()
        self._meta[ptrs] = batch
    return ptrs, np.array(ep_rews), np.array(ep_lens), np.array(ep_idxs)
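# Usage sketch for the method above. Illustrative only: it assumes access
# through tianshou.data.VectorReplayBuffer (a ReplayBufferManager subclass)
# and made-up shapes for a 4-env setup.
#
#     import numpy as np
#     from tianshou.data import Batch, VectorReplayBuffer
#
#     buf = VectorReplayBuffer(total_size=40, buffer_num=4)  # 4 sub-buffers
#     batch = Batch(
#         obs=np.zeros((4, 3)), act=np.zeros(4), rew=np.ones(4),
#         done=np.array([False, True, False, False]),
#         obs_next=np.zeros((4, 3)),
#         info=np.array([{} for _ in range(4)]),
#     )
#     # buffer_ids defaults to [0, 1, 2, 3]; the returned ptrs/ep_idxs are
#     # global indices, i.e. per-buffer indices shifted by each offset.
#     ptrs, ep_rews, ep_lens, ep_idxs = buf.add(batch)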
def collect(
    self,
    n_step: Optional[int] = None,
    n_episode: Optional[int] = None,
    random: bool = False,
    render: Optional[float] = None,
    no_grad: bool = True,
) -> Dict[str, Any]:
    """Collect a specified number of steps or episodes with the async env setting.

    This function does not collect exactly n_step or n_episode transitions.
    Instead, in order to support the async setting, it may collect more than
    the given n_step or n_episode transitions and save them into the buffer.

    :param int n_step: how many steps you want to collect.
    :param int n_episode: how many episodes you want to collect.
    :param bool random: whether to use random policy for collecting data.
        Default to False.
    :param float render: the sleep time between rendering consecutive frames.
        Default to None (no rendering).
    :param bool no_grad: whether to retain gradient in policy.forward().
        Default to True (no gradient retaining).

    .. note::

        One and only one collection number specification is permitted, either
        ``n_step`` or ``n_episode``.

    :return: A dict including the following keys

        * ``n/ep`` collected number of episodes.
        * ``n/st`` collected number of steps.
        * ``rews`` array of episode reward over collected episodes.
        * ``lens`` array of episode length over collected episodes.
        * ``idxs`` array of episode start index in buffer over collected
          episodes.
    """
    # collect at least n_step or n_episode
    if n_step is not None:
        assert n_episode is None, (
            "Only one of n_step or n_episode is allowed in Collector."
            f"collect, got n_step={n_step}, n_episode={n_episode}."
        )
        assert n_step > 0
    elif n_episode is not None:
        assert n_episode > 0
    else:
        raise TypeError(
            "Please specify at least one (either n_step or n_episode) "
            "in AsyncCollector.collect()."
        )
    warnings.warn(
        "Using async setting may collect extra transitions into buffer."
    )

    ready_env_ids = self._ready_env_ids

    start_time = time.time()

    step_count = 0
    episode_count = 0
    episode_rews = []
    episode_lens = []
    episode_start_indices = []

    while True:
        whole_data = self.data
        self.data = self.data[ready_env_ids]
        assert len(whole_data) == self.env_num  # major difference
        # restore the state: if the last state is None, it won't be stored
        last_state = self.data.policy.pop("hidden_state", None)

        # get the next action
        if random:
            self.data.update(
                act=[self._action_space[i].sample() for i in ready_env_ids]
            )
        else:
            if no_grad:
                with torch.no_grad():  # faster than retain_grad version
                    # self.data.obs will be used by agent to get result
                    result = self.policy(self.data, last_state)
            else:
                result = self.policy(self.data, last_state)
            # update state / act / policy into self.data
            policy = result.get("policy", Batch())
            assert isinstance(policy, Batch)
            state = result.get("state", None)
            if state is not None:
                policy.hidden_state = state  # save state into buffer
            act = to_numpy(result.act)
            if self.exploration_noise:
                act = self.policy.exploration_noise(act, self.data)
            self.data.update(policy=policy, act=act)

        # save act/policy before env.step
        try:
            whole_data.act[ready_env_ids] = self.data.act
            whole_data.policy[ready_env_ids] = self.data.policy
        except ValueError:
            _alloc_by_keys_diff(whole_data, self.data, self.env_num, False)
            whole_data[ready_env_ids] = self.data  # lots of overhead

        # get bounded and remapped actions first (not saved into buffer)
        action_remap = self.policy.map_action(self.data.act)
        # step in env
        result = self.env.step(action_remap, ready_env_ids)  # type: ignore
        obs_next, rew, done, info = result

        # change self.data here because ready_env_ids has changed
        ready_env_ids = np.array([i["env_id"] for i in info])
        self.data = whole_data[ready_env_ids]

        self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
        if self.preprocess_fn:
            self.data.update(
                self.preprocess_fn(
                    obs_next=self.data.obs_next,
                    rew=self.data.rew,
                    done=self.data.done,
                    info=self.data.info,
                    env_id=ready_env_ids,
                )
            )

        if render:
            self.env.render()
            if render > 0 and not np.isclose(render, 0):
                time.sleep(render)

        # add data into the buffer
        ptr, ep_rew, ep_len, ep_idx = self.buffer.add(
            self.data, buffer_ids=ready_env_ids
        )

        # collect statistics
        step_count += len(ready_env_ids)

        if np.any(done):
            env_ind_local = np.where(done)[0]
            env_ind_global = ready_env_ids[env_ind_local]
            episode_count += len(env_ind_local)
            episode_lens.append(ep_len[env_ind_local])
            episode_rews.append(ep_rew[env_ind_local])
            episode_start_indices.append(ep_idx[env_ind_local])
            # now we copy obs_next to obs, but since there might be
            # finished episodes, we have to reset finished envs first.
            obs_reset = self.env.reset(env_ind_global)
            if self.preprocess_fn:
                obs_reset = self.preprocess_fn(
                    obs=obs_reset, env_id=env_ind_global
                ).get("obs", obs_reset)
            self.data.obs_next[env_ind_local] = obs_reset
            for i in env_ind_local:
                self._reset_state(i)

        try:
            whole_data.obs[ready_env_ids] = self.data.obs_next
            whole_data.rew[ready_env_ids] = self.data.rew
            whole_data.done[ready_env_ids] = self.data.done
            whole_data.info[ready_env_ids] = self.data.info
        except ValueError:
            _alloc_by_keys_diff(whole_data, self.data, self.env_num, False)
            self.data.obs = self.data.obs_next
            whole_data[ready_env_ids] = self.data  # lots of overhead
        self.data = whole_data

        if (n_step and step_count >= n_step) or \
                (n_episode and episode_count >= n_episode):
            break

    self._ready_env_ids = ready_env_ids

    # generate statistics
    self.collect_step += step_count
    self.collect_episode += episode_count
    self.collect_time += max(time.time() - start_time, 1e-9)

    if episode_count > 0:
        rews, lens, idxs = list(
            map(
                np.concatenate,
                [episode_rews, episode_lens, episode_start_indices]
            )
        )
    else:
        rews, lens, idxs = np.array([]), np.array([], int), np.array([], int)

    return {
        "n/ep": episode_count,
        "n/st": step_count,
        "rews": rews,
        "lens": lens,
        "idxs": idxs,
    }
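# Usage sketch for the method above. Illustrative only: it assumes the
# surrounding class is tianshou.data.AsyncCollector and elides the policy
# and env_fns construction.
#
#     envs = SubprocVectorEnv(env_fns, wait_num=2)  # async vectorized env
#     buffer = VectorReplayBuffer(total_size=20000, buffer_num=len(env_fns))
#     collector = AsyncCollector(policy, envs, buffer)
#     result = collector.collect(n_step=100)
#     # In the async setting result["n/st"] may exceed 100; the extra
#     # transitions are still written into the buffer.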