class Collector(object): """The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param env: a ``gym.Env`` environment or an instance of the :class:`~tianshou.env.BaseVectorEnv` class. :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class, or a list of :class:`~tianshou.data.ReplayBuffer`. If set to ``None``, it will automatically assign a small-size :class:`~tianshou.data.ReplayBuffer`. :param function preprocess_fn: a function called before the data has been added to the buffer, see issue #42, defaults to ``None``. :param int stat_size: for the moving average of recording speed, defaults to 100. The ``preprocess_fn`` is a function called before the data has been added to the buffer with batch format, which receives up to 7 keys as listed in :class:`~tianshou.data.Batch`. It will receive with only ``obs`` when the collector resets the environment. It returns either a dict or a :class:`~tianshou.data.Batch` with the modified keys and values. Examples are in "test/base/test_collector.py". Example: :: policy = PGPolicy(...) # or other policies if you wish env = gym.make('CartPole-v0') replay_buffer = ReplayBuffer(size=10000) # here we set up a collector with a single environment collector = Collector(policy, env, buffer=replay_buffer) # the collector supports vectorized environments as well envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)]) buffers = [ReplayBuffer(size=5000) for _ in range(3)] # you can also pass a list of replay buffer to collector, for multi-env # collector = Collector(policy, envs, buffer=buffers) collector = Collector(policy, envs, buffer=replay_buffer) # collect at least 3 episodes collector.collect(n_episode=3) # collect 1 episode for the first env, 3 for the third env collector.collect(n_episode=[1, 0, 3]) # collect at least 2 steps collector.collect(n_step=2) # collect episodes with visual rendering (the render argument is the # sleep time between rendering consecutive frames) collector.collect(n_episode=1, render=0.03) # sample data with a given number of batch-size: batch_data = collector.sample(batch_size=64) # policy.learn(batch_data) # btw, vanilla policy gradient only # supports on-policy training, so here we pick all data in the buffer batch_data = collector.sample(batch_size=0) policy.learn(batch_data) # on-policy algorithms use the collected data only once, so here we # clear the buffer collector.reset_buffer() For the scenario of collecting data from multiple environments to a single buffer, the cache buffers will turn on automatically. It may return the data more than the given limitation. .. note:: Please make sure the given environment has a time limitation. 
""" def __init__(self, policy: BasePolicy, env: Union[gym.Env, BaseVectorEnv], buffer: Optional[Union[ReplayBuffer, List[ReplayBuffer]]] = None, preprocess_fn: Callable[[Any], Union[dict, Batch]] = None, stat_size: Optional[int] = 100, **kwargs) -> None: super().__init__() self.env = env self.env_num = 1 self.collect_time = 0 self.collect_step = 0 self.collect_episode = 0 self.buffer = buffer self.policy = policy self.preprocess_fn = preprocess_fn # if preprocess_fn is None: # def _prep(**kwargs): # return kwargs # self.preprocess_fn = _prep self.process_fn = policy.process_fn self._multi_env = isinstance(env, BaseVectorEnv) self._multi_buf = False # True if buf is a list # need multiple cache buffers only if storing in one buffer self._cached_buf = [] if self._multi_env: self.env_num = len(env) if isinstance(self.buffer, list): assert len(self.buffer) == self.env_num, \ 'The number of data buffer does not match the number of ' \ 'input env.' self._multi_buf = True elif isinstance(self.buffer, ReplayBuffer) or self.buffer is None: self._cached_buf = [ ListReplayBuffer() for _ in range(self.env_num)] else: raise TypeError('The buffer in data collector is invalid!') self.stat_size = stat_size self.reset() def reset(self) -> None: """Reset all related variables in the collector.""" self.reset_env() self.reset_buffer() # state over batch is either a list, an np.ndarray, or a torch.Tensor self.state = None self.step_speed = MovAvg(self.stat_size) self.episode_speed = MovAvg(self.stat_size) self.collect_step = 0 self.collect_episode = 0 self.collect_time = 0 def reset_buffer(self) -> None: """Reset the main data buffer.""" if self._multi_buf: for b in self.buffer: b.reset() else: if self.buffer is not None: self.buffer.reset() def get_env_num(self) -> int: """Return the number of environments the collector have.""" return self.env_num def reset_env(self) -> None: """Reset all of the environment(s)' states and reset all of the cache buffers (if need). 
""" self._obs = self.env.reset() if not self._multi_env: self._obs = self._make_batch(self._obs) if self.preprocess_fn: self._obs = self.preprocess_fn(obs=self._obs).get('obs', self._obs) self._act = self._rew = self._done = self._info = None if self._multi_env: self.reward = np.zeros(self.env_num) self.length = np.zeros(self.env_num) else: self.reward, self.length = 0, 0 for b in self._cached_buf: b.reset() def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None: """Reset all the seed(s) of the given environment(s).""" if hasattr(self.env, 'seed'): return self.env.seed(seed) def render(self, **kwargs) -> None: """Render all the environment(s).""" if hasattr(self.env, 'render'): return self.env.render(**kwargs) def close(self) -> None: """Close the environment(s).""" if hasattr(self.env, 'close'): self.env.close() def _make_batch(self, data: Any) -> np.ndarray: """Return [data].""" if isinstance(data, np.ndarray): return data[None] else: return np.array([data]) def _reset_state(self, id: Union[int, List[int]]) -> None: """Reset self.state[id].""" if self.state is None: return if isinstance(self.state, list): self.state[id] = None elif isinstance(self.state, (dict, Batch)): for k in self.state.keys(): if isinstance(self.state[k], list): self.state[k][id] = None elif isinstance(self.state[k], (torch.Tensor, np.ndarray)): self.state[k][id] = 0 elif isinstance(self.state, (torch.Tensor, np.ndarray)): self.state[id] = 0 def collect(self, n_step: int = 0, n_episode: Union[int, List[int]] = 0, random: bool = False, render: Optional[float] = None, log_fn: Optional[Callable[[dict], None]] = None ) -> Dict[str, float]: """Collect a specified number of step or episode. :param int n_step: how many steps you want to collect. :param n_episode: how many episodes you want to collect (in each environment). :type n_episode: int or list :param bool random: whether to use random policy for collecting data, defaults to ``False``. :param float render: the sleep time between rendering consecutive frames, defaults to ``None`` (no rendering). :param function log_fn: a function which receives env info, typically for tensorboard logging. .. note:: One and only one collection number specification is permitted, either ``n_step`` or ``n_episode``. :return: A dict including the following keys * ``n/ep`` the collected number of episodes. * ``n/st`` the collected number of steps. * ``v/st`` the speed of steps per second. * ``v/ep`` the speed of episode per second. * ``rew`` the mean reward over collected episodes. * ``len`` the mean length over collected episodes. """ warning_count = 0 if not self._multi_env: n_episode = np.sum(n_episode) start_time = time.time() assert sum([(n_step != 0), (n_episode != 0)]) == 1, \ "One and only one collection number specification is permitted!" cur_step = 0 cur_episode = np.zeros(self.env_num) if self._multi_env else 0 reward_sum = 0 length_sum = 0 while True: if warning_count >= 100000: warnings.warn( 'There are already many steps in an episode. 
            batch = Batch(
                obs=self._obs, act=self._act, rew=self._rew,
                done=self._done, obs_next=None, info=self._info,
                policy=None)
            if random:
                action_space = self.env.action_space
                if isinstance(action_space, list):
                    result = Batch(act=[a.sample() for a in action_space])
                else:
                    result = Batch(act=self._make_batch(
                        action_space.sample()))
            else:
                with torch.no_grad():
                    result = self.policy(batch, self.state)
            self.state = result.get('state', None)
            self._policy = to_numpy(result.policy) \
                if hasattr(result, 'policy') else [{}] * self.env_num
            self._act = to_numpy(result.act)
            obs_next, self._rew, self._done, self._info = self.env.step(
                self._act if self._multi_env else self._act[0])
            if not self._multi_env:
                obs_next = self._make_batch(obs_next)
                self._rew = self._make_batch(self._rew)
                self._done = self._make_batch(self._done)
                self._info = self._make_batch(self._info)
            if log_fn:
                log_fn(self._info if self._multi_env else self._info[0])
            if render:
                self.env.render()
                if render > 0:
                    time.sleep(render)
            self.length += 1
            self.reward += self._rew
            if self.preprocess_fn:
                result = self.preprocess_fn(
                    obs=self._obs, act=self._act, rew=self._rew,
                    done=self._done, obs_next=obs_next, info=self._info,
                    policy=self._policy)
                self._obs = result.get('obs', self._obs)
                self._act = result.get('act', self._act)
                self._rew = result.get('rew', self._rew)
                self._done = result.get('done', self._done)
                obs_next = result.get('obs_next', obs_next)
                self._info = result.get('info', self._info)
                self._policy = result.get('policy', self._policy)
            if self._multi_env:
                for i in range(self.env_num):
                    data = {
                        'obs': self._obs[i], 'act': self._act[i],
                        'rew': self._rew[i], 'done': self._done[i],
                        'obs_next': obs_next[i], 'info': self._info[i],
                        'policy': self._policy[i]}
                    if self._cached_buf:
                        warning_count += 1
                        self._cached_buf[i].add(**data)
                    elif self._multi_buf:
                        warning_count += 1
                        self.buffer[i].add(**data)
                        cur_step += 1
                    else:
                        warning_count += 1
                        if self.buffer is not None:
                            self.buffer.add(**data)
                        cur_step += 1
                    if self._done[i]:
                        if n_step != 0 or np.isscalar(n_episode) or \
                                cur_episode[i] < n_episode[i]:
                            cur_episode[i] += 1
                            reward_sum += self.reward[i]
                            length_sum += self.length[i]
                            if self._cached_buf:
                                cur_step += len(self._cached_buf[i])
                                if self.buffer is not None:
                                    self.buffer.update(self._cached_buf[i])
                        self.reward[i], self.length[i] = 0, 0
                        if self._cached_buf:
                            self._cached_buf[i].reset()
                        self._reset_state(i)
                if sum(self._done):
                    obs_next = self.env.reset(np.where(self._done)[0])
                    if self.preprocess_fn:
                        obs_next = self.preprocess_fn(obs=obs_next).get(
                            'obs', obs_next)
                if n_episode != 0:
                    if isinstance(n_episode, list) and \
                            (cur_episode >= np.array(n_episode)).all() or \
                            np.isscalar(n_episode) and \
                            cur_episode.sum() >= n_episode:
                        break
            else:
                if self.buffer is not None:
                    self.buffer.add(
                        self._obs[0], self._act[0], self._rew[0],
                        self._done[0], obs_next[0], self._info[0],
                        self._policy[0])
                cur_step += 1
                if self._done:
                    cur_episode += 1
                    reward_sum += self.reward[0]
                    length_sum += self.length
                    self.reward, self.length = 0, 0
                    self.state = None
                    obs_next = self._make_batch(self.env.reset())
                    if self.preprocess_fn:
                        obs_next = self.preprocess_fn(obs=obs_next).get(
                            'obs', obs_next)
                if n_episode != 0 and cur_episode >= n_episode:
                    break
            if n_step != 0 and cur_step >= n_step:
                break
            self._obs = obs_next
        self._obs = obs_next
        if self._multi_env:
            cur_episode = sum(cur_episode)
        duration = max(time.time() - start_time, 1e-9)
        self.step_speed.add(cur_step / duration)
        self.episode_speed.add(cur_episode / duration)
        self.collect_step += cur_step
        self.collect_episode += cur_episode
        self.collect_time += duration
        if isinstance(n_episode, list):
            n_episode = np.sum(n_episode)
        else:
            n_episode = max(cur_episode, 1)
        return {
            'n/ep': cur_episode,
            'n/st': cur_step,
            'v/st': self.step_speed.get(),
            'v/ep': self.episode_speed.get(),
            'rew': reward_sum / n_episode,
            'len': length_sum / n_episode,
        }

    def sample(self, batch_size: int) -> Batch:
        """Sample a data batch from the internal replay buffer. It will call
        :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the
        final batch data.

        :param int batch_size: ``0`` means it will extract all the data from
            the buffer, otherwise it will extract the data with the given
            batch_size.
        """
        if self._multi_buf:
            if batch_size > 0:
                lens = [len(b) for b in self.buffer]
                total = sum(lens)
                batch_index = np.random.choice(
                    len(self.buffer), batch_size, p=np.array(lens) / total)
            else:
                batch_index = np.array([])
            batch_data = Batch()
            for i, b in enumerate(self.buffer):
                cur_batch = (batch_index == i).sum()
                if batch_size and cur_batch or batch_size <= 0:
                    batch, indice = b.sample(cur_batch)
                    batch = self.process_fn(batch, b, indice)
                    batch_data.append(batch)
        else:
            batch_data, indice = self.buffer.sample(batch_size)
            batch_data = self.process_fn(batch_data, self.buffer, indice)
        return batch_data
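

# A hedged sketch of the ``log_fn`` hook accepted by ``Collector.collect``
# above. ``log_fn`` receives the environment's ``info`` each step (a single
# dict for a non-vectorized env), so it can forward selected entries to
# tensorboard. The 'episode_reward' key and the writer wiring are
# illustrative assumptions, not part of the library.
def _example_log_fn_usage(collector: 'Collector') -> None:
    from torch.utils.tensorboard import SummaryWriter  # needs tensorboard
    writer = SummaryWriter('log/collector')
    step_count = {'n': 0}

    def log_fn(info: dict) -> None:
        # called once per environment step with the env's info dict
        step_count['n'] += 1
        if 'episode_reward' in info:
            writer.add_scalar('env/episode_reward',
                              info['episode_reward'], step_count['n'])

    collector.collect(n_step=100, log_fn=log_fn)
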
class Collector(object): """The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param env: a ``gym.Env`` environment or an instance of the :class:`~tianshou.env.BaseVectorEnv` class. :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class, or a list of :class:`~tianshou.data.ReplayBuffer`. If set to ``None``, it will automatically assign a small-size :class:`~tianshou.data.ReplayBuffer`. :param function preprocess_fn: a function called before the data has been added to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults to ``None``. :param int stat_size: for the moving average of recording speed, defaults to 100. :param BaseNoise action_noise: add a noise to continuous action. Normally a policy already has a noise param for exploration in training phase, so this is recommended to use in test collector for some purpose. :param function reward_metric: to be used in multi-agent RL. The reward to report is of shape [agent_num], but we need to return a single scalar to monitor training. This function specifies what is the desired metric, e.g., the reward of agent 1 or the average reward over all agents. By default, the behavior is to select the reward of agent 1. The ``preprocess_fn`` is a function called before the data has been added to the buffer with batch format, which receives up to 7 keys as listed in :class:`~tianshou.data.Batch`. It will receive with only ``obs`` when the collector resets the environment. It returns either a dict or a :class:`~tianshou.data.Batch` with the modified keys and values. Examples are in "test/base/test_collector.py". Example: :: policy = PGPolicy(...) # or other policies if you wish env = gym.make('CartPole-v0') replay_buffer = ReplayBuffer(size=10000) # here we set up a collector with a single environment collector = Collector(policy, env, buffer=replay_buffer) # the collector supports vectorized environments as well envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)]) buffers = [ReplayBuffer(size=5000) for _ in range(3)] # you can also pass a list of replay buffer to collector, for multi-env # collector = Collector(policy, envs, buffer=buffers) collector = Collector(policy, envs, buffer=replay_buffer) # collect at least 3 episodes collector.collect(n_episode=3) # collect 1 episode for the first env, 3 for the third env collector.collect(n_episode=[1, 0, 3]) # collect at least 2 steps collector.collect(n_step=2) # collect episodes with visual rendering (the render argument is the # sleep time between rendering consecutive frames) collector.collect(n_episode=1, render=0.03) # sample data with a given number of batch-size: batch_data = collector.sample(batch_size=64) # policy.learn(batch_data) # btw, vanilla policy gradient only # supports on-policy training, so here we pick all data in the buffer batch_data = collector.sample(batch_size=0) policy.learn(batch_data) # on-policy algorithms use the collected data only once, so here we # clear the buffer collector.reset_buffer() For the scenario of collecting data from multiple environments to a single buffer, the cache buffers will turn on automatically. It may return the data more than the given limitation. .. note:: Please make sure the given environment has a time limitation. 
""" def __init__( self, policy: BasePolicy, env: Union[gym.Env, BaseVectorEnv], buffer: Optional[ReplayBuffer] = None, preprocess_fn: Callable[[Any], Union[dict, Batch]] = None, stat_size: Optional[int] = 100, action_noise: Optional[BaseNoise] = None, reward_metric: Optional[Callable[[np.ndarray], float]] = None, ) -> None: super().__init__() self.env = env self.env_num = 1 self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0 self.buffer = buffer self.policy = policy self.preprocess_fn = preprocess_fn self.process_fn = policy.process_fn self._multi_env = isinstance(env, BaseVectorEnv) # need multiple cache buffers only if storing in one buffer self._cached_buf = [] if self._multi_env: self.env_num = len(env) self._cached_buf = [ ListReplayBuffer() for _ in range(self.env_num) ] self.stat_size = stat_size self._action_noise = action_noise self._rew_metric = reward_metric or Collector._default_rew_metric self.reset() @staticmethod def _default_rew_metric(x): # this internal function is designed for single-agent RL # for multi-agent RL, a reward_metric must be provided assert np.asanyarray(x).size == 1, \ 'Please specify the reward_metric ' \ 'since the reward is not a scalar.' return x def reset(self) -> None: """Reset all related variables in the collector.""" self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={}, obs_next={}, policy={}) self.reset_env() self.reset_buffer() self.step_speed = MovAvg(self.stat_size) self.episode_speed = MovAvg(self.stat_size) self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0 if self._action_noise is not None: self._action_noise.reset() def reset_buffer(self) -> None: """Reset the main data buffer.""" if self.buffer is not None: self.buffer.reset() def get_env_num(self) -> int: """Return the number of environments the collector have.""" return self.env_num def reset_env(self) -> None: """Reset all of the environment(s)' states and reset all of the cache buffers (if need). """ obs = self.env.reset() if not self._multi_env: obs = self._make_batch(obs) if self.preprocess_fn: obs = self.preprocess_fn(obs=obs).get('obs', obs) self.data.obs = obs self.reward = 0. # will be specified when the first data is ready self.length = np.zeros(self.env_num) for b in self._cached_buf: b.reset() def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None: """Reset all the seed(s) of the given environment(s).""" return self.env.seed(seed) def render(self, **kwargs) -> None: """Render all the environment(s).""" return self.env.render(**kwargs) def close(self) -> None: """Close the environment(s).""" self.env.close() def _make_batch(self, data: Any) -> np.ndarray: """Return [data].""" if isinstance(data, np.ndarray): return data[None] else: return np.array([data]) def _reset_state(self, id: Union[int, List[int]]) -> None: """Reset self.data.state[id].""" state = self.data.state # it is a reference if isinstance(state, torch.Tensor): state[id].zero_() elif isinstance(state, np.ndarray): state[id] = None if state.dtype == np.object else 0 elif isinstance(state, Batch): state.empty_(id) def collect( self, n_step: int = 0, n_episode: Union[int, List[int]] = 0, random: bool = False, render: Optional[float] = None, log_fn: Optional[Callable[[dict], None]] = None) -> Dict[str, float]: """Collect a specified number of step or episode. :param int n_step: how many steps you want to collect. :param n_episode: how many episodes you want to collect (in each environment). 
        :type n_episode: int or list
        :param bool random: whether to use a random policy for collecting
            data, defaults to ``False``.
        :param float render: the sleep time between rendering consecutive
            frames, defaults to ``None`` (no rendering).
        :param function log_fn: a function which receives env info, typically
            for tensorboard logging.

        .. note::

            One and only one collection number specification is permitted,
            either ``n_step`` or ``n_episode``.

        :return: A dict including the following keys

            * ``n/ep`` the collected number of episodes.
            * ``n/st`` the collected number of steps.
            * ``v/st`` the speed of steps per second.
            * ``v/ep`` the speed of episodes per second.
            * ``rew`` the mean reward over collected episodes.
            * ``len`` the mean length over collected episodes.
        """
        if not self._multi_env:
            n_episode = np.sum(n_episode)
        start_time = time.time()
        assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
            "One and only one collection number specification is permitted!"
        cur_step, cur_episode = 0, np.zeros(self.env_num)
        reward_sum, length_sum = 0., 0

        # change: accumulators for custom per-episode environment metrics
        ty1_succ_rate_1 = 0.
        ty1_succ_rate_2 = 0.
        ty1_succ_rate_3 = 0.
        Q_len_1 = 0.
        Q_len_2 = 0.
        Q_len_3 = 0.
        energy_effi_1 = 0.
        energy_effi_2 = 0.
        energy_effi_3 = 0.
        avg_rate = 0.
        avg_power = 0.

        while True:
            if cur_step >= 100000 and cur_episode.sum() == 0:
                warnings.warn(
                    'There are already many steps in an episode. '
                    'You should add a time limitation to your environment!',
                    Warning)

            # restore the state and the input data
            last_state = self.data.state
            if last_state.is_empty():
                last_state = None
            self.data.update(state=Batch(), obs_next=Batch(), policy=Batch())

            # calculate the next action
            if random:
                action_space = self.env.action_space
                if isinstance(action_space, list):
                    result = Batch(act=[a.sample() for a in action_space])
                else:
                    result = Batch(act=self._make_batch(
                        action_space.sample()))
            else:
                with torch.no_grad():
                    result = self.policy(self.data, last_state)

            # convert None to Batch(), since None is reserved for 0-init
            state = result.get('state', Batch())
            if state is None:
                state = Batch()
            self.data.state = state
            if hasattr(result, 'policy'):
                self.data.policy = to_numpy(result.policy)
            # save hidden state to policy._state, in order to save into buffer
            self.data.policy._state = self.data.state

            self.data.act = to_numpy(result.act)
            if self._action_noise is not None:
                self.data.act += self._action_noise(self.data.act.shape)

            # step in env
            obs_next, rew, done, info = self.env.step(
                self.data.act if self._multi_env else self.data.act[0])

            # move data to self.data
            if not self._multi_env:
                obs_next = self._make_batch(obs_next)
                rew = self._make_batch(rew)
                done = self._make_batch(done)
                info = self._make_batch(info)
            self.data.obs_next = obs_next
            self.data.rew = rew
            self.data.done = done
            self.data.info = info

            if log_fn:
                log_fn(info if self._multi_env else info[0])
            if render:
                self.render()
                if render > 0:
                    time.sleep(render)

            # add data into the buffer
            self.length += 1
            self.reward += self.data.rew
            if self.preprocess_fn:
                result = self.preprocess_fn(**self.data)
                self.data.update(result)
            if self._multi_env:  # cache_buffer branch
                # change: accumulate custom metrics from env 0's info at the
                # end of each episode
                if self.data.done[0]:
                    ty1_succ_rate_1 += self.data.info[0]['ty1_succ_rate_1']
                    ty1_succ_rate_2 += self.data.info[0]['ty1_succ_rate_2']
                    ty1_succ_rate_3 += self.data.info[0]['ty1_succ_rate_3']
                    Q_len_1 += self.data.info[0]['Q_len_1']
                    Q_len_2 += self.data.info[0]['Q_len_2']
                    Q_len_3 += self.data.info[0]['Q_len_3']
                    energy_effi_1 += self.data.info[0]['energy_effi_1']
                    energy_effi_2 += self.data.info[0]['energy_effi_2']
                    energy_effi_3 += self.data.info[0]['energy_effi_3']
                    avg_rate += self.data.info[0]['avg_rate']
                    avg_power += self.data.info[0]['avg_power']
                for i in range(self.env_num):
                    self._cached_buf[i].add(**self.data[i])
                    if self.data.done[i]:
                        if n_step != 0 or np.isscalar(n_episode) or \
                                cur_episode[i] < n_episode[i]:
                            cur_episode[i] += 1
                            reward_sum += self.reward[i]
                            length_sum += self.length[i]
                            if self._cached_buf:
                                cur_step += len(self._cached_buf[i])
                                if self.buffer is not None:
                                    self.buffer.update(self._cached_buf[i])
                        self.reward[i], self.length[i] = 0., 0
                        if self._cached_buf:
                            self._cached_buf[i].reset()
                        self._reset_state(i)
                obs_next = self.data.obs_next
                if sum(self.data.done):
                    env_ind = np.where(self.data.done)[0]
                    obs_reset = self.env.reset(env_ind)
                    if self.preprocess_fn:
                        obs_next[env_ind] = self.preprocess_fn(
                            obs=obs_reset).get('obs', obs_reset)
                    else:
                        obs_next[env_ind] = obs_reset
                self.data.obs_next = obs_next
                if n_episode != 0:
                    if isinstance(n_episode, list) and \
                            (cur_episode >= np.array(n_episode)).all() or \
                            np.isscalar(n_episode) and \
                            cur_episode.sum() >= n_episode:
                        break
            else:  # single buffer, without cache_buffer
                if self.buffer is not None:
                    self.buffer.add(**self.data[0])
                cur_step += 1
                if self.data.done[0]:
                    # change: accumulate custom metrics at episode end
                    ty1_succ_rate_1 += self.data.info['ty1_succ_rate_1']
                    ty1_succ_rate_2 += self.data.info['ty1_succ_rate_2']
                    ty1_succ_rate_3 += self.data.info['ty1_succ_rate_3']
                    Q_len_1 += self.data.info['Q_len_1']
                    Q_len_2 += self.data.info['Q_len_2']
                    Q_len_3 += self.data.info['Q_len_3']
                    energy_effi_1 += self.data.info['energy_effi_1']
                    energy_effi_2 += self.data.info['energy_effi_2']
                    energy_effi_3 += self.data.info['energy_effi_3']
                    avg_rate += self.data.info[0]['avg_rate']
                    avg_power += self.data.info[0]['avg_power']
                    cur_episode += 1
                    reward_sum += self.reward[0]
                    length_sum += self.length[0]
                    self.reward, self.length = 0., np.zeros(self.env_num)
                    self.data.state = Batch()
                    obs_next = self._make_batch(self.env.reset())
                    if self.preprocess_fn:
                        obs_next = self.preprocess_fn(obs=obs_next).get(
                            'obs', obs_next)
                    self.data.obs_next = obs_next
                if n_episode != 0 and cur_episode >= n_episode:
                    break
            if n_step != 0 and cur_step >= n_step:
                break
            self.data.obs = self.data.obs_next
        self.data.obs = self.data.obs_next

        # generate the statistics
        cur_episode = sum(cur_episode)
        duration = max(time.time() - start_time, 1e-9)
        self.step_speed.add(cur_step / duration)
        self.episode_speed.add(cur_episode / duration)
        self.collect_step += cur_step
        self.collect_episode += cur_episode
        self.collect_time += duration
        if isinstance(n_episode, list):
            n_episode = np.sum(n_episode)
        else:
            n_episode = max(cur_episode, 1)
        reward_sum /= n_episode
        if np.asanyarray(reward_sum).size > 1:  # non-scalar reward_sum
            reward_sum = self._rew_metric(reward_sum)
        # change: report the accumulated custom metrics alongside the
        # standard statistics
        return {
            'n/ep': cur_episode,
            'n/st': cur_step,
            'v/st': self.step_speed.get(),
            'v/ep': self.episode_speed.get(),
            'rew': reward_sum,
            'len': length_sum / n_episode,
            'ty1s_1': ty1_succ_rate_1,
            'ty1s_2': ty1_succ_rate_2,
            'ty1s_3': ty1_succ_rate_3,
            'ql_1': Q_len_1,
            'ql_2': Q_len_2,
            'ql_3': Q_len_3,
            'ee_1': energy_effi_1,
            'ee_2': energy_effi_2,
            'ee_3': energy_effi_3,
            'avg_r': avg_rate,
            'avg_p': avg_power,
        }

    def sample(self, batch_size: int) -> Batch:
        """Sample a data batch from the internal replay buffer. It will call
        :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the
        final batch data.

        :param int batch_size: ``0`` means it will extract all the data from
            the buffer, otherwise it will extract the data with the given
            batch_size.
""" batch_data, indice = self.buffer.sample(batch_size) batch_data = self.process_fn(batch_data, self.buffer, indice) return batch_data
class Collector(object): """docstring for Collector""" def __init__(self, policy, env, buffer=None, stat_size=100): super().__init__() self.env = env self.env_num = 1 self.collect_step = 0 self.collect_episode = 0 self.collect_time = 0 if buffer is None: self.buffer = ReplayBuffer(100) else: self.buffer = buffer self.policy = policy self.process_fn = policy.process_fn self._multi_env = isinstance(env, BaseVectorEnv) self._multi_buf = False # True if buf is a list # need multiple cache buffers only if storing in one buffer self._cached_buf = [] if self._multi_env: self.env_num = len(env) if isinstance(self.buffer, list): assert len(self.buffer) == self.env_num, \ 'The number of data buffer does not match the number of ' \ 'input env.' self._multi_buf = True elif isinstance(self.buffer, ReplayBuffer): self._cached_buf = [ ListReplayBuffer() for _ in range(self.env_num) ] else: raise TypeError('The buffer in data collector is invalid!') self.reset_env() self.reset_buffer() # state over batch is either a list, an np.ndarray, or a torch.Tensor self.state = None self.step_speed = MovAvg(stat_size) self.episode_speed = MovAvg(stat_size) def reset_buffer(self): if self._multi_buf: for b in self.buffer: b.reset() else: self.buffer.reset() def get_env_num(self): return self.env_num def reset_env(self): self._obs = self.env.reset() self._act = self._rew = self._done = self._info = None if self._multi_env: self.reward = np.zeros(self.env_num) self.length = np.zeros(self.env_num) else: self.reward, self.length = 0, 0 for b in self._cached_buf: b.reset() def seed(self, seed=None): if hasattr(self.env, 'seed'): return self.env.seed(seed) def render(self, **kwargs): if hasattr(self.env, 'render'): return self.env.render(**kwargs) def close(self): if hasattr(self.env, 'close'): self.env.close() def _make_batch(self, data): if isinstance(data, np.ndarray): return data[None] else: return np.array([data]) def collect(self, n_step=0, n_episode=0, render=0): warning_count = 0 if not self._multi_env: n_episode = np.sum(n_episode) start_time = time.time() assert sum([(n_step != 0), (n_episode != 0)]) == 1, \ "One and only one collection number specification permitted!" cur_step = 0 cur_episode = np.zeros(self.env_num) if self._multi_env else 0 reward_sum = 0 length_sum = 0 while True: if warning_count >= 100000: warnings.warn( 'There are already many steps in an episode. 
            if self._multi_env:
                batch_data = Batch(
                    obs=self._obs, act=self._act, rew=self._rew,
                    done=self._done, obs_next=None, info=self._info)
            else:
                batch_data = Batch(
                    obs=self._make_batch(self._obs),
                    act=self._make_batch(self._act),
                    rew=self._make_batch(self._rew),
                    done=self._make_batch(self._done),
                    obs_next=None,
                    info=self._make_batch(self._info))
            result = self.policy(batch_data, self.state)
            self.state = result.state if hasattr(result, 'state') else None
            if isinstance(result.act, torch.Tensor):
                self._act = result.act.detach().cpu().numpy()
            elif not isinstance(self._act, np.ndarray):
                self._act = np.array(result.act)
            else:
                self._act = result.act
            obs_next, self._rew, self._done, self._info = self.env.step(
                self._act if self._multi_env else self._act[0])
            if render > 0:
                self.env.render()
                time.sleep(render)
            self.length += 1
            self.reward += self._rew
            if self._multi_env:
                for i in range(self.env_num):
                    data = {
                        'obs': self._obs[i], 'act': self._act[i],
                        'rew': self._rew[i], 'done': self._done[i],
                        'obs_next': obs_next[i], 'info': self._info[i]}
                    if self._cached_buf:
                        warning_count += 1
                        self._cached_buf[i].add(**data)
                    elif self._multi_buf:
                        warning_count += 1
                        self.buffer[i].add(**data)
                        cur_step += 1
                    else:
                        warning_count += 1
                        self.buffer.add(**data)
                        cur_step += 1
                    if self._done[i]:
                        if n_step != 0 or np.isscalar(n_episode) or \
                                cur_episode[i] < n_episode[i]:
                            cur_episode[i] += 1
                            reward_sum += self.reward[i]
                            length_sum += self.length[i]
                            if self._cached_buf:
                                cur_step += len(self._cached_buf[i])
                                self.buffer.update(self._cached_buf[i])
                        self.reward[i], self.length[i] = 0, 0
                        if self._cached_buf:
                            self._cached_buf[i].reset()
                        if isinstance(self.state, list):
                            self.state[i] = None
                        elif self.state is not None:
                            if isinstance(self.state[i], dict):
                                self.state[i] = {}
                            else:
                                self.state[i] = self.state[i] * 0
                            if isinstance(self.state, torch.Tensor):
                                # remove ref count in pytorch (?)
                                self.state = self.state.detach()
                if sum(self._done):
                    obs_next = self.env.reset(np.where(self._done)[0])
                if n_episode != 0:
                    if isinstance(n_episode, list) and \
                            (cur_episode >= np.array(n_episode)).all() or \
                            np.isscalar(n_episode) and \
                            cur_episode.sum() >= n_episode:
                        break
            else:
                self.buffer.add(self._obs, self._act[0], self._rew,
                                self._done, obs_next, self._info)
                cur_step += 1
                if self._done:
                    cur_episode += 1
                    reward_sum += self.reward
                    length_sum += self.length
                    self.reward, self.length = 0, 0
                    self.state = None
                    obs_next = self.env.reset()
                if n_episode != 0 and cur_episode >= n_episode:
                    break
            if n_step != 0 and cur_step >= n_step:
                break
            self._obs = obs_next
        self._obs = obs_next
        if self._multi_env:
            cur_episode = sum(cur_episode)
        duration = time.time() - start_time
        self.step_speed.add(cur_step / duration)
        self.episode_speed.add(cur_episode / duration)
        self.collect_step += cur_step
        self.collect_episode += cur_episode
        self.collect_time += duration
        if isinstance(n_episode, list):
            n_episode = np.sum(n_episode)
        else:
            n_episode = max(cur_episode, 1)
        return {
            'n/ep': cur_episode,
            'n/st': cur_step,
            'v/st': self.step_speed.get(),
            'v/ep': self.episode_speed.get(),
            'rew': reward_sum / n_episode,
            'len': length_sum / n_episode,
        }

    def sample(self, batch_size):
        if self._multi_buf:
            if batch_size > 0:
                lens = [len(b) for b in self.buffer]
                total = sum(lens)
                # choose which buffer each sample comes from, weighted by
                # buffer length (indices must range over the buffers, not
                # over the total number of transitions)
                batch_index = np.random.choice(
                    len(self.buffer), batch_size, p=np.array(lens) / total)
            else:
                batch_index = np.array([])
            batch_data = Batch()
            for i, b in enumerate(self.buffer):
                cur_batch = (batch_index == i).sum()
                if batch_size and cur_batch or batch_size <= 0:
                    batch, indice = b.sample(cur_batch)
                    batch = self.process_fn(batch, b, indice)
                    batch_data.append(batch)
        else:
            batch_data, indice = self.buffer.sample(batch_size)
            batch_data = self.process_fn(batch_data, self.buffer, indice)
        return batch_data
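

# A hedged sketch of the collect/sample/learn loop the Collector above is
# built for. The policy, environment, step budget and buffer size are
# placeholders; only the Collector calls follow the API shown in this file.
def _example_training_loop(policy, env, epochs: int = 10) -> None:
    buf = ReplayBuffer(size=20000)
    collector = Collector(policy, env, buffer=buf)
    for _ in range(epochs):
        # gather fresh experience, then learn from a sampled mini-batch
        stats = collector.collect(n_step=1000)
        batch = collector.sample(batch_size=64)
        policy.learn(batch)
        print('mean episode reward:', stats['rew'])
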
class Collector(object): """The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param env: an environment or an instance of the :class:`~tianshou.env.BaseVectorEnv` class. :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class, or a list of :class:`~tianshou.data.ReplayBuffer`. If set to ``None``, it will automatically assign a small-size :class:`~tianshou.data.ReplayBuffer`. :param int stat_size: for the moving average of recording speed, defaults to 100. :param bool store_obs_next: whether to store the obs_next to replay buffer, defaults to ``True``. Example: :: policy = PGPolicy(...) # or other policies if you wish env = gym.make('CartPole-v0') replay_buffer = ReplayBuffer(size=10000) # here we set up a collector with a single environment collector = Collector(policy, env, buffer=replay_buffer) # the collector supports vectorized environments as well envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)]) buffers = [ReplayBuffer(size=5000) for _ in range(3)] # you can also pass a list of replay buffer to collector, for multi-env # collector = Collector(policy, envs, buffer=buffers) collector = Collector(policy, envs, buffer=replay_buffer) # collect at least 3 episodes collector.collect(n_episode=3) # collect 1 episode for the first env, 3 for the third env collector.collect(n_episode=[1, 0, 3]) # collect at least 2 steps collector.collect(n_step=2) # collect episodes with visual rendering (the render argument is the # sleep time between rendering consecutive frames) collector.collect(n_episode=1, render=0.03) # sample data with a given number of batch-size: batch_data = collector.sample(batch_size=64) # policy.learn(batch_data) # btw, vanilla policy gradient only # supports on-policy training, so here we pick all data in the buffer batch_data = collector.sample(batch_size=0) policy.learn(batch_data) # on-policy algorithms use the collected data only once, so here we # clear the buffer collector.reset_buffer() For the scenario of collecting data from multiple environments to a single buffer, the cache buffers will turn on automatically. It may return the data more than the given limitation. .. note:: Please make sure the given environment has a time limitation. """ def __init__(self, policy, env, buffer=None, stat_size=100, store_obs_next=True, **kwargs): super().__init__() self.env = env self.env_num = 1 self.collect_step = 0 self.collect_episode = 0 self.collect_time = 0 if buffer is None: self.buffer = ReplayBuffer(100) else: self.buffer = buffer self.policy = policy self.process_fn = policy.process_fn self._multi_env = isinstance(env, BaseVectorEnv) self._multi_buf = False # True if buf is a list # need multiple cache buffers only if storing in one buffer self._cached_buf = [] if self._multi_env: self.env_num = len(env) if isinstance(self.buffer, list): assert len(self.buffer) == self.env_num, \ 'The number of data buffer does not match the number of ' \ 'input env.' 
                self._multi_buf = True
            elif isinstance(self.buffer, ReplayBuffer):
                self._cached_buf = [
                    ListReplayBuffer() for _ in range(self.env_num)]
            else:
                raise TypeError('The buffer in data collector is invalid!')
        self.reset_env()
        self.reset_buffer()
        # state over batch is either a list, an np.ndarray, or a torch.Tensor
        self.state = None
        self.step_speed = MovAvg(stat_size)
        self.episode_speed = MovAvg(stat_size)
        self._save_s_ = store_obs_next

    def reset_buffer(self):
        """Reset the main data buffer."""
        if self._multi_buf:
            for b in self.buffer:
                b.reset()
        else:
            self.buffer.reset()

    def get_env_num(self):
        """Return the number of environments the collector has."""
        return self.env_num

    def reset_env(self):
        """Reset all of the environment(s)' states and reset all of the
        cache buffers (if needed).
        """
        self._obs = self.env.reset()
        self._act = self._rew = self._done = self._info = None
        if self._multi_env:
            self.reward = np.zeros(self.env_num)
            self.length = np.zeros(self.env_num)
        else:
            self.reward, self.length = 0, 0
        for b in self._cached_buf:
            b.reset()

    def seed(self, seed=None):
        """Reset all the seed(s) of the given environment(s)."""
        if hasattr(self.env, 'seed'):
            return self.env.seed(seed)

    def render(self, **kwargs):
        """Render all the environment(s)."""
        if hasattr(self.env, 'render'):
            return self.env.render(**kwargs)

    def close(self):
        """Close the environment(s)."""
        if hasattr(self.env, 'close'):
            self.env.close()

    def _make_batch(self, data):
        if isinstance(data, np.ndarray):
            return data[None]
        else:
            return np.array([data])

    def collect(self, n_step=0, n_episode=0, render=0):
        """Collect a specified number of steps or episodes.

        :param int n_step: how many steps you want to collect.
        :param n_episode: how many episodes you want to collect (in each
            environment).
        :type n_episode: int or list
        :param float render: the sleep time between rendering consecutive
            frames. No rendering if it is ``0`` (the default).

        .. note::

            One and only one collection number specification is permitted,
            either ``n_step`` or ``n_episode``.

        :return: A dict including the following keys

            * ``n/ep`` the collected number of episodes.
            * ``n/st`` the collected number of steps.
            * ``v/st`` the speed of steps per second.
            * ``v/ep`` the speed of episodes per second.
            * ``rew`` the mean reward over collected episodes.
            * ``len`` the mean length over collected episodes.
        """
        warning_count = 0
        if not self._multi_env:
            n_episode = np.sum(n_episode)
        start_time = time.time()
        assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
            "One and only one collection number specification is permitted!"
        cur_step = 0
        cur_episode = np.zeros(self.env_num) if self._multi_env else 0
        reward_sum = 0
        length_sum = 0
        while True:
            if warning_count >= 100000:
                warnings.warn(
                    'There are already many steps in an episode. '
                    'You should add a time limitation to your environment!',
                    Warning)
            if self._multi_env:
                batch_data = Batch(
                    obs=self._obs, act=self._act, rew=self._rew,
                    done=self._done, obs_next=None, info=self._info)
            else:
                batch_data = Batch(
                    obs=self._make_batch(self._obs),
                    act=self._make_batch(self._act),
                    rew=self._make_batch(self._rew),
                    done=self._make_batch(self._done),
                    obs_next=None,
                    info=self._make_batch(self._info))
            with torch.no_grad():
                result = self.policy(batch_data, self.state)
            self.state = result.state if hasattr(result, 'state') else None
            if isinstance(result.act, torch.Tensor):
                self._act = result.act.detach().cpu().numpy()
            elif not isinstance(self._act, np.ndarray):
                self._act = np.array(result.act)
            else:
                self._act = result.act
            obs_next, self._rew, self._done, self._info = self.env.step(
                self._act if self._multi_env else self._act[0])
            if render > 0:
                self.env.render()
                time.sleep(render)
            self.length += 1
            self.reward += self._rew
            if self._multi_env:
                for i in range(self.env_num):
                    data = {
                        'obs': self._obs[i], 'act': self._act[i],
                        'rew': self._rew[i], 'done': self._done[i],
                        'obs_next': obs_next[i] if self._save_s_ else None,
                        'info': self._info[i]}
                    if self._cached_buf:
                        warning_count += 1
                        self._cached_buf[i].add(**data)
                    elif self._multi_buf:
                        warning_count += 1
                        self.buffer[i].add(**data)
                        cur_step += 1
                    else:
                        warning_count += 1
                        self.buffer.add(**data)
                        cur_step += 1
                    if self._done[i]:
                        if n_step != 0 or np.isscalar(n_episode) or \
                                cur_episode[i] < n_episode[i]:
                            cur_episode[i] += 1
                            reward_sum += self.reward[i]
                            length_sum += self.length[i]
                            if self._cached_buf:
                                cur_step += len(self._cached_buf[i])
                                self.buffer.update(self._cached_buf[i])
                        self.reward[i], self.length[i] = 0, 0
                        if self._cached_buf:
                            self._cached_buf[i].reset()
                        if isinstance(self.state, list):
                            self.state[i] = None
                        elif self.state is not None:
                            if isinstance(self.state[i], dict):
                                self.state[i] = {}
                            else:
                                self.state[i] = self.state[i] * 0
                            if isinstance(self.state, torch.Tensor):
                                # remove ref count in pytorch (?)
                                self.state = self.state.detach()
                if sum(self._done):
                    obs_next = self.env.reset(np.where(self._done)[0])
                if n_episode != 0:
                    if isinstance(n_episode, list) and \
                            (cur_episode >= np.array(n_episode)).all() or \
                            np.isscalar(n_episode) and \
                            cur_episode.sum() >= n_episode:
                        break
            else:
                self.buffer.add(
                    self._obs, self._act[0], self._rew, self._done,
                    obs_next if self._save_s_ else None, self._info)
                cur_step += 1
                if self._done:
                    cur_episode += 1
                    reward_sum += self.reward
                    length_sum += self.length
                    self.reward, self.length = 0, 0
                    self.state = None
                    obs_next = self.env.reset()
                if n_episode != 0 and cur_episode >= n_episode:
                    break
            if n_step != 0 and cur_step >= n_step:
                break
            self._obs = obs_next
        self._obs = obs_next
        if self._multi_env:
            cur_episode = sum(cur_episode)
        duration = time.time() - start_time
        self.step_speed.add(cur_step / duration)
        self.episode_speed.add(cur_episode / duration)
        self.collect_step += cur_step
        self.collect_episode += cur_episode
        self.collect_time += duration
        if isinstance(n_episode, list):
            n_episode = np.sum(n_episode)
        else:
            n_episode = max(cur_episode, 1)
        return {
            'n/ep': cur_episode,
            'n/st': cur_step,
            'v/st': self.step_speed.get(),
            'v/ep': self.episode_speed.get(),
            'rew': reward_sum / n_episode,
            'len': length_sum / n_episode,
        }

    def sample(self, batch_size):
        """Sample a data batch from the internal replay buffer. It will call
        :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the
        final batch data.

        :param int batch_size: ``0`` means it will extract all the data from
            the buffer, otherwise it will extract the data with the given
            batch_size.
        """
        if self._multi_buf:
            if batch_size > 0:
                lens = [len(b) for b in self.buffer]
                total = sum(lens)
                # choose which buffer each sample comes from, weighted by
                # buffer length (indices must range over the buffers, not
                # over the total number of transitions)
                batch_index = np.random.choice(
                    len(self.buffer), batch_size, p=np.array(lens) / total)
            else:
                batch_index = np.array([])
            batch_data = Batch()
            for i, b in enumerate(self.buffer):
                cur_batch = (batch_index == i).sum()
                if batch_size and cur_batch or batch_size <= 0:
                    batch, indice = b.sample(cur_batch)
                    batch = self.process_fn(batch, b, indice)
                    batch_data.append(batch)
        else:
            batch_data, indice = self.buffer.sample(batch_size)
            batch_data = self.process_fn(batch_data, self.buffer, indice)
        return batch_data
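

# A hedged sketch of the ``store_obs_next`` switch introduced in the variant
# above: with store_obs_next=False the collector writes None into the
# obs_next slot of every transition it adds, which can save buffer memory
# when the training algorithm does not need obs_next. The environment name
# and sizes are placeholders.
def _example_store_obs_next(policy) -> None:
    env = gym.make('CartPole-v0')
    buf = ReplayBuffer(size=10000)
    collector = Collector(policy, env, buffer=buf, store_obs_next=False)
    collector.collect(n_step=500)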