def test_batch():
    batch = Batch(obs=[0], np=np.zeros([3, 4]))
    batch.update(obs=[1])
    assert batch.obs == [1]
    batch.append(batch)
    assert batch.obs == [1, 1]
    assert batch.np.shape == (6, 4)
    assert batch[0].obs == batch[1].obs
    with pytest.raises(IndexError):
        batch[2]
    batch.obs = np.arange(5)
    for i, b in enumerate(batch.split(1, permute=False)):
        assert b.obs == batch[i].obs
def test_batch_over_batch():
    batch = Batch(a=[3, 4, 5], b=[4, 5, 6])
    batch2 = Batch({'c': [6, 7, 8], 'b': batch})
    batch2.b.b[-1] = 0
    print(batch2)
    for k, v in batch2.items():
        assert np.all(batch2[k] == v)
    assert batch2[-1].b.b == 0
    batch2.cat_(Batch(c=[6, 7, 8], b=batch))
    assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8])
    assert np.allclose(batch2.b.a, [3, 4, 5, 3, 4, 5])
    assert np.allclose(batch2.b.b, [4, 5, 0, 4, 5, 0])
    batch2.update(batch2.b, six=[6, 6, 6])
    assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8])
    assert np.allclose(batch2.a, [3, 4, 5, 3, 4, 5])
    assert np.allclose(batch2.b, [4, 5, 0, 4, 5, 0])
    assert np.allclose(batch2.six, [6, 6, 6])
    d = {'a': [3, 4, 5], 'b': [4, 5, 6]}
    batch3 = Batch(c=[6, 7, 8], b=d)
    batch3.cat_(Batch(c=[6, 7, 8], b=d))
    assert np.allclose(batch3.c, [6, 7, 8, 6, 7, 8])
    assert np.allclose(batch3.b.a, [3, 4, 5, 3, 4, 5])
    assert np.allclose(batch3.b.b, [4, 5, 6, 4, 5, 6])
    batch4 = Batch(({'a': {'b': np.array([1.0])}},))
    assert batch4.a.b.ndim == 2
    assert batch4.a.b[0, 0] == 1.0
    # advanced slicing
    batch5 = Batch(a=[[1, 2]], b={'c': np.zeros([3, 2, 1])})
    assert batch5.shape == [1, 2]
    with pytest.raises(IndexError):
        batch5[2]
    with pytest.raises(IndexError):
        batch5[:, 3]
    with pytest.raises(IndexError):
        batch5[:, :, -1]
    batch5[:, -1] += 1
    assert np.allclose(batch5.a, [1, 3])
    assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3)
    with pytest.raises(ValueError):
        batch5[:, -1] = 1
    batch5[:, 0] = {'a': -1}
    assert np.allclose(batch5.a, [-1, 3])
    assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3)
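# A minimal standalone sketch of the nested-Batch semantics exercised by the
# test above: `cat_` concatenates leaf arrays key-wise through nested
# batches, and `update()` can hoist an inner batch's keys to the top level.
# Only tianshou's Batch and numpy are assumed.
import numpy as np
from tianshou.data import Batch

inner = Batch(a=[1, 2])
outer = Batch(c=[0, 0], b=inner)
outer.cat_(Batch(c=[1, 1], b=inner))  # concatenates leaf arrays key-wise
assert np.allclose(outer.c, [0, 0, 1, 1])
assert np.allclose(outer.b.a, [1, 2, 1, 2])
outer.update(outer.b)                 # hoist inner keys to the top level
assert np.allclose(outer.a, [1, 2, 1, 2])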
class base_attack_collector:
    """Base class for adversarial-attack collectors.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param env: a ``gym.Env`` environment or an instance of the
        :class:`~tianshou.env.BaseVectorEnv` class.
    :param obs_adv_atk: an instance of the
        :class:`~advertorch.attacks.base.Attack` class implementing an image
        adversarial attack.
    :param perfect_attack: force adversarial attacks on observations to be
        always effective (ignore the ``adv`` param).
    """

    def __init__(self,
                 policy: BasePolicy,
                 env: gym.Env,
                 obs_adv_atk: Attack,
                 perfect_attack: bool = False,
                 device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.policy = policy
        self.env = env
        self.obs_adv_atk = obs_adv_atk
        self.perfect_attack = perfect_attack
        self.action_space = self.env.action_space.shape or \
            self.env.action_space.n
        self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={},
                          obs_next={}, policy={})
        self.reset_env()
        self.episode_count = 0  # current number of episodes
        self.reward_total = 0.  # total episode cumulative reward
        self.frames_count = 0   # number of observed frames
        self.n_attacks = 0      # number of attacks performed
        self.succ_attacks = 0   # number of successful image attacks
        self.start_time = 0     # time when the attack starts

    def reset_env(self):
        self.data.obs = self.env.reset()

    def render(self, **kwargs) -> None:
        return self.env.render(**kwargs)

    def reset_attack(self):
        self.episode_count, self.reward_total, self.frames_count, \
            self.n_attacks, self.succ_attacks = 0, 0, 0, 0, 0
        self.start_time = time.time()

    def get_attack_stats(self) -> Dict[str, float]:
        duration = max(time.time() - self.start_time, 1e-9)
        if self.episode_count == 0:
            self.episode_count = 1
        return {
            'n/ep': self.episode_count,
            'n/st': self.frames_count,
            'v/st': self.frames_count / duration,
            'v/ep': self.episode_count / duration,
            'rew': self.reward_total / self.episode_count,
            'len': self.frames_count / self.episode_count,
            'n_atks': self.n_attacks / self.episode_count,
            'n_succ_atks': self.succ_attacks / self.episode_count,
            'atk_rate(%)': self.n_attacks / self.frames_count,
            'succ_atks(%)': self.succ_attacks / self.n_attacks
            if self.n_attacks > 0 else 0,
        }

    def show_warning(self):
        if self.frames_count >= 100000 and self.episode_count == 0:
            warnings.warn(
                'There are already many steps in an episode. '
                'You should add a time limitation to your environment!',
                Warning)

    def check_end_attack(self, n_step, n_episode) -> bool:
        """Return True when the attack terminates."""
        if n_step and self.frames_count >= n_step:
            return True
        if n_episode and self.episode_count >= n_episode:
            return True
        return False

    def perform_step(self):
        """Perform action 'self.data.act' on 'self.env' and store the next
        observation in 'self.data.obs'."""
        obs_next, rew, done, info = self.env.step(self.data.act[0])
        self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
        self.reward_total += rew
        if self.data.done:
            self.episode_count += 1
            self.reset_env()
        self.data.obs = self.data.obs_next

    def predict_next_action(self):
        """Predict the next action given observation 'self.data.obs' and
        policy 'self.policy', and store it in 'self.data.act'.

        :return: outcome of the policy forward pass.
        """
        with torch.no_grad():
            self.data.obs = np.expand_dims(self.data.obs, axis=0)
            result = self.policy(self.data, last_state=None)
        self.data.act = to_numpy(result.act)
        return result

    def obs_attacks(
            self,
            target_action: Optional[List[int]] = None,
    ):
        """Perform an image adversarial attack on the observation stored in
        'self.data.obs' with respect to the action 'target_action', using the
        method defined in 'self.obs_adv_atk'.

        :param target_action:
            - if obs_adv_atk.targeted=False, 'target_action' must be the
              normal action;
            - if obs_adv_atk.targeted=True, 'target_action' must be the
              adversarial action.
        """
        if not target_action:
            target_action = self.data.act
        # convert observation and action to tensors
        obs = torch.FloatTensor(self.data.obs).to(self.device)
        act = torch.tensor(target_action).to(self.device)
        # create the adversarial observation
        adv_obs = self.obs_adv_atk.perturb(obs, act)
        with torch.no_grad():
            data = copy.deepcopy(self.data)
            data.obs = adv_obs.cpu().detach().numpy()
            result = self.policy(data, last_state=None)
        self.data.act = to_numpy(result.act)

    def collect(self,
                n_step: int = 0,
                n_episode: int = 0,
                render: Optional[float] = None) -> Dict[str, float]:
        """
        :param int n_step: how many steps you want to collect.
        :param n_episode: how many episodes you want to collect.
        :param float render: the sleep time between rendering consecutive
            frames, defaults to ``None`` (no rendering).

        :return: A dict including the following keys

            * ``n/ep`` the collected number of episodes.
            * ``n/st`` the collected number of steps.
            * ``v/st`` the speed of steps per second.
            * ``v/ep`` the speed of episodes per second.
            * ``rew`` the mean reward over collected episodes.
            * ``len`` the mean length over collected episodes.
            * ``n_atks`` the number of performed attacks per episode.
            * ``n_succ_atks`` the number of successful attacks per episode.
            * ``atk_rate(%)`` the ratio of performed attacks over steps.
            * ``succ_atks(%)`` the ratio of successful attacks over performed
              attacks.
        """
        raise NotImplementedError("Sub-classes must implement 'collect'.")
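# A hedged sketch (not part of the original code) of how a concrete subclass
# might implement `collect` on top of base_attack_collector: attack every
# frame until n_step/n_episode is reached. The class name and the
# success-counting rule (an attack "succeeds" when it changes the chosen
# action) are illustrative assumptions, not this project's implementation.
class uniform_attack_collector(base_attack_collector):
    def collect(self, n_step: int = 0, n_episode: int = 0,
                render: Optional[float] = None) -> Dict[str, float]:
        self.reset_attack()
        while not self.check_end_attack(n_step, n_episode):
            self.predict_next_action()
            ori_act = self.data.act[0]
            if not self.perfect_attack:
                # perturb the observation; this may change self.data.act
                self.obs_attacks()
            if self.perfect_attack or self.data.act[0] != ori_act:
                self.succ_attacks += 1
            self.n_attacks += 1
            self.frames_count += 1
            if render is not None:
                self.render()
                time.sleep(render)
            self.perform_step()
        return self.get_attack_stats()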
class Collector(object):
    """Collector enables the policy to interact with different types of envs with \
    exact numbers of steps or episodes.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
    :param env: a ``gym.Env`` environment or an instance of the
        :class:`~tianshou.env.BaseVectorEnv` class.
    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
        If set to None, it will not store the data. Default to None.
    :param function preprocess_fn: a function called before the data has been added
        to the buffer, see issue #42 and :ref:`preprocess_fn`. Default to None.
    :param bool exploration_noise: determine whether the action needs to be
        modified with the corresponding policy's exploration noise. If so,
        "policy.exploration_noise(act, batch)" will be called automatically to
        add the exploration noise into the action. Default to False.

    The "preprocess_fn" is a function called before the data has been added to the
    buffer with batch format. It will receive only "obs" and "env_id" when the
    collector resets the environment, and will receive the six keys "obs_next",
    "rew", "done", "info", "policy" and "env_id" in a normal env step. It returns
    either a dict or a :class:`~tianshou.data.Batch` with the modified keys and
    values. Examples are in "test/base/test_collector.py".

    .. note::

        Please make sure the given environment has a time limitation if using the
        n_episode collect option.
    """

    def __init__(
        self,
        policy: BasePolicy,
        env: Union[gym.Env, BaseVectorEnv],
        buffer: Optional[ReplayBuffer] = None,
        preprocess_fn: Optional[Callable[..., Batch]] = None,
        exploration_noise: bool = False,
    ) -> None:
        super().__init__()
        if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
            warnings.warn("Single environment detected, wrap to DummyVectorEnv.")
            env = DummyVectorEnv([lambda: env])
        self.env = env
        self.env_num = len(env)
        self.exploration_noise = exploration_noise
        self._assign_buffer(buffer)
        self.policy = policy
        self.preprocess_fn = preprocess_fn
        self._action_space = env.action_space
        # avoid creating attribute outside __init__
        self.reset()

    def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None:
        """Check if the buffer matches the constraint."""
        if buffer is None:
            buffer = VectorReplayBuffer(self.env_num, self.env_num)
        elif isinstance(buffer, ReplayBufferManager):
            assert buffer.buffer_num >= self.env_num
            if isinstance(buffer, CachedReplayBuffer):
                assert buffer.cached_buffer_num >= self.env_num
        else:  # ReplayBuffer or PrioritizedReplayBuffer
            assert buffer.maxsize > 0
            if self.env_num > 1:
                if type(buffer) == ReplayBuffer:
                    buffer_type = "ReplayBuffer"
                    vector_type = "VectorReplayBuffer"
                else:
                    buffer_type = "PrioritizedReplayBuffer"
                    vector_type = "PrioritizedVectorReplayBuffer"
                raise TypeError(
                    f"Cannot use {buffer_type}(size={buffer.maxsize}, ...) to collect "
                    f"{self.env_num} envs,\n\tplease use {vector_type}(total_size="
                    f"{buffer.maxsize}, buffer_num={self.env_num}, ...) instead."
                )
        self.buffer = buffer

    def reset(self) -> None:
        """Reset all related variables in the collector."""
        # use empty Batch for "state" so that self.data supports slicing
        # convert empty Batch to None when passing data to policy
        self.data = Batch(obs={}, act={}, rew={}, done={}, obs_next={},
                          info={}, policy={})
        self.reset_env()
        self.reset_buffer()
        self.reset_stat()

    def reset_stat(self) -> None:
        """Reset the statistic variables."""
        self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0

    def reset_buffer(self, keep_statistics: bool = False) -> None:
        """Reset the data buffer."""
        self.buffer.reset(keep_statistics=keep_statistics)

    def reset_env(self) -> None:
        """Reset all of the environments."""
        obs = self.env.reset()
        if self.preprocess_fn:
            obs = self.preprocess_fn(
                obs=obs, env_id=np.arange(self.env_num)).get("obs", obs)
        self.data.obs = obs

    def _reset_state(self, id: Union[int, List[int]]) -> None:
        """Reset the hidden state: self.data.policy.hidden_state[id]."""
        if hasattr(self.data.policy, "hidden_state"):
            state = self.data.policy.hidden_state  # it is a reference
            if isinstance(state, torch.Tensor):
                state[id].zero_()
            elif isinstance(state, np.ndarray):
                state[id] = None if state.dtype == object else 0
            elif isinstance(state, Batch):
                state.empty_(id)

    def collect(
        self,
        n_step: Optional[int] = None,
        n_episode: Optional[int] = None,
        random: bool = False,
        render: Optional[float] = None,
        no_grad: bool = True,
    ) -> Dict[str, Any]:
        """Collect a specified number of steps or episodes.

        To ensure an unbiased sampling result with the n_episode option, this
        function will first collect ``n_episode - env_num`` episodes, then the
        last ``env_num`` episodes will be collected evenly from each env.

        :param int n_step: how many steps you want to collect.
        :param int n_episode: how many episodes you want to collect.
        :param bool random: whether to use random policy for collecting data.
            Default to False.
        :param float render: the sleep time between rendering consecutive frames.
            Default to None (no rendering).
        :param bool no_grad: whether to retain gradient in policy.forward().
            Default to True (no gradient retaining).

        .. note::

            One and only one collection number specification is permitted,
            either ``n_step`` or ``n_episode``.

        :return: A dict including the following keys

            * ``n/ep`` collected number of episodes.
            * ``n/st`` collected number of steps.
            * ``rews`` array of episode reward over collected episodes.
            * ``lens`` array of episode length over collected episodes.
            * ``idxs`` array of episode start index in buffer over collected episodes.
        """
        assert not self.env.is_async, "Please use AsyncCollector if using async venv."
        if n_step is not None:
            assert n_episode is None, (
                f"Only one of n_step or n_episode is allowed in Collector."
                f"collect, got n_step={n_step}, n_episode={n_episode}.")
            assert n_step > 0
            if not n_step % self.env_num == 0:
                warnings.warn(
                    f"n_step={n_step} is not a multiple of #env ({self.env_num}), "
                    "which may cause extra transitions collected into the buffer."
                )
            ready_env_ids = np.arange(self.env_num)
        elif n_episode is not None:
            assert n_episode > 0
            ready_env_ids = np.arange(min(self.env_num, n_episode))
            self.data = self.data[:min(self.env_num, n_episode)]
        else:
            raise TypeError(
                "Please specify at least one (either n_step or n_episode) "
                "in Collector.collect().")

        start_time = time.time()

        step_count = 0
        episode_count = 0
        episode_rews = []
        episode_lens = []
        episode_start_indices = []

        while True:
            assert len(self.data) == len(ready_env_ids)
            # restore the state: if the last state is None, it won't store
            last_state = self.data.policy.pop("hidden_state", None)

            # get the next action
            if random:
                self.data.update(
                    act=[self._action_space[i].sample() for i in ready_env_ids])
            else:
                if no_grad:
                    with torch.no_grad():  # faster than retain_grad version
                        # self.data.obs will be used by agent to get result
                        result = self.policy(self.data, last_state)
                else:
                    result = self.policy(self.data, last_state)
                # update state / act / policy into self.data
                policy = result.get("policy", Batch())
                assert isinstance(policy, Batch)
                state = result.get("state", None)
                if state is not None:
                    policy.hidden_state = state  # save state into buffer
                act = to_numpy(result.act)
                if self.exploration_noise:
                    act = self.policy.exploration_noise(act, self.data)
                self.data.update(policy=policy, act=act)

            # get bounded and remapped actions first (not saved into buffer)
            action_remap = self.policy.map_action(self.data.act)
            # step in env
            result = self.env.step(action_remap, ready_env_ids)  # type: ignore
            obs_next, rew, done, info = result

            self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
            if self.preprocess_fn:
                self.data.update(
                    self.preprocess_fn(
                        obs_next=self.data.obs_next,
                        rew=self.data.rew,
                        done=self.data.done,
                        info=self.data.info,
                        policy=self.data.policy,
                        env_id=ready_env_ids,
                    ))

            if render:
                self.env.render()
                if render > 0 and not np.isclose(render, 0):
                    time.sleep(render)

            # add data into the buffer
            ptr, ep_rew, ep_len, ep_idx = self.buffer.add(
                self.data, buffer_ids=ready_env_ids)

            # collect statistics
            step_count += len(ready_env_ids)

            if np.any(done):
                env_ind_local = np.where(done)[0]
                env_ind_global = ready_env_ids[env_ind_local]
                episode_count += len(env_ind_local)
                episode_lens.append(ep_len[env_ind_local])
                episode_rews.append(ep_rew[env_ind_local])
                episode_start_indices.append(ep_idx[env_ind_local])
                # now we copy obs_next to obs, but since there might be
                # finished episodes, we have to reset finished envs first.
                obs_reset = self.env.reset(env_ind_global)
                if self.preprocess_fn:
                    obs_reset = self.preprocess_fn(
                        obs=obs_reset, env_id=env_ind_global).get("obs", obs_reset)
                self.data.obs_next[env_ind_local] = obs_reset
                for i in env_ind_local:
                    self._reset_state(i)

                # remove surplus env id from ready_env_ids
                # to avoid bias in selecting environments
                if n_episode:
                    surplus_env_num = \
                        len(ready_env_ids) - (n_episode - episode_count)
                    if surplus_env_num > 0:
                        mask = np.ones_like(ready_env_ids, dtype=bool)
                        mask[env_ind_local[:surplus_env_num]] = False
                        ready_env_ids = ready_env_ids[mask]
                        self.data = self.data[mask]

            self.data.obs = self.data.obs_next

            if (n_step and step_count >= n_step) or \
                    (n_episode and episode_count >= n_episode):
                break

        # generate statistics
        self.collect_step += step_count
        self.collect_episode += episode_count
        self.collect_time += max(time.time() - start_time, 1e-9)

        if n_episode:
            self.data = Batch(obs={}, act={}, rew={}, done={}, obs_next={},
                              info={}, policy={})
            self.reset_env()

        if episode_count > 0:
            rews, lens, idxs = list(
                map(np.concatenate,
                    [episode_rews, episode_lens, episode_start_indices]))
        else:
            rews, lens, idxs = np.array([]), np.array([], int), np.array([], int)

        return {
            "n/ep": episode_count,
            "n/st": step_count,
            "rews": rews,
            "lens": lens,
            "idxs": idxs,
        }
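# A hedged usage sketch for the Collector above, assuming the tianshou
# 0.4-style API it reflects and gym's CartPole. RandomPolicy is a stand-in
# defined here only for illustration; it is not part of tianshou.
import gym
import numpy as np
from tianshou.data import Batch, Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import BasePolicy


class RandomPolicy(BasePolicy):
    """Minimal concrete policy: uniform random actions."""

    def forward(self, batch, state=None, **kwargs):
        # one sampled action per observation in the batch
        act = np.array([self.action_space.sample() for _ in batch.obs])
        return Batch(act=act)

    def learn(self, batch, **kwargs):
        return {}


envs = DummyVectorEnv([lambda: gym.make("CartPole-v0") for _ in range(2)])
policy = RandomPolicy(action_space=gym.make("CartPole-v0").action_space)
buf = VectorReplayBuffer(total_size=2000, buffer_num=2)
collector = Collector(policy, envs, buffer=buf)
stats = collector.collect(n_step=10)  # n_step should be a multiple of #env
print(stats["n/st"], stats["n/ep"])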
def test_batch():
    assert list(Batch()) == []
    assert Batch().is_empty()
    assert not Batch(b={'c': {}}).is_empty()
    assert Batch(b={'c': {}}).is_empty(recurse=True)
    assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
    assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
    assert not Batch(d=1).is_empty()
    assert not Batch(a=np.float64(1.0)).is_empty()
    assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
    assert not Batch(a=[1, 2, 3]).is_empty()
    b = Batch({'a': [4, 4], 'b': [5, 5]}, c=[None, None])
    assert b.c.dtype == object
    b = Batch(d=[None], e=[starmap], f=Batch)
    assert b.d.dtype == b.e.dtype == object and b.f == Batch
    b = Batch()
    b.update()
    assert b.is_empty()
    b.update(c=[3, 5])
    assert np.allclose(b.c, [3, 5])
    # mimic the behavior of dict.update, where kwargs can overwrite keys
    b.update({'a': 2}, a=3)
    assert 'a' in b and b.a == 3
    assert b.pop('a') == 3
    assert 'a' not in b
    with pytest.raises(AssertionError):
        Batch({1: 2})
    assert Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]).a.dtype == object
    with pytest.raises(TypeError):
        Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])
    batch = Batch(a=[torch.ones(3), torch.ones(3)])
    assert torch.allclose(batch.a, torch.ones(2, 3))
    batch.cat_(batch)
    assert torch.allclose(batch.a, torch.ones(4, 3))
    Batch(a=[])
    batch = Batch(obs=[0], np=np.zeros([3, 4]))
    assert batch.obs == batch["obs"]
    batch.obs = [1]
    assert batch.obs == [1]
    batch.cat_(batch)
    assert np.allclose(batch.obs, [1, 1])
    assert batch.np.shape == (6, 4)
    assert np.allclose(batch[0].obs, batch[1].obs)
    batch.obs = np.arange(5)
    for i, b in enumerate(batch.split(1, shuffle=False)):
        if i != 5:
            assert b.obs == batch[i].obs
        else:
            with pytest.raises(AttributeError):
                batch[i].obs
            with pytest.raises(AttributeError):
                b.obs
    print(batch)
    batch = Batch(a=np.arange(10))
    with pytest.raises(AssertionError):
        list(batch.split(0))
    data = [
        (1, False, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]),
        (1, True, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]),
        (3, False, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]),
        (3, True, [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]),
        (5, False, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
        (5, True, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
        (7, False, [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]),
        (7, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (10, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (10, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (15, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (15, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (100, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (100, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
    ]
    for size, merge_last, result in data:
        bs = list(batch.split(size, shuffle=False, merge_last=merge_last))
        assert [bs[i].a.tolist() for i in range(len(bs))] == result
    batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])}
    batch_item = Batch({'a': [batch_dict]})[0]
    assert isinstance(batch_item.a.b, np.ndarray)
    assert batch_item.a.b == batch_dict['b']
    assert isinstance(batch_item.a.c, float)
    assert batch_item.a.c == batch_dict['c']
    assert isinstance(batch_item.a.d, torch.Tensor)
    assert batch_item.a.d == batch_dict['d']
    batch2 = Batch(a=[{
        'b': np.float64(1.0),
        'c': np.zeros(1),
        'd': Batch(e=np.array(3.0))}])
    assert len(batch2) == 1
    assert Batch().shape == []
    assert Batch(a=1).shape == []
    assert Batch(a=set((1, 2, 1))).shape == []
    assert batch2.shape[0] == 1
    assert 'a' in batch2 and all([i in batch2.a for i in 'bcd'])
    with pytest.raises(IndexError):
        batch2[-2]
    with pytest.raises(IndexError):
        batch2[1]
    assert batch2[0].shape == []
    with pytest.raises(IndexError):
        batch2[0][0]
    with pytest.raises(TypeError):
        len(batch2[0])
    assert isinstance(batch2[0].a.c, np.ndarray)
    assert isinstance(batch2[0].a.b, np.float64)
    assert isinstance(batch2[0].a.d.e, np.float64)
    batch2_from_list = Batch(list(batch2))
    batch2_from_comp = Batch([e for e in batch2])
    assert batch2_from_list.a.b == batch2.a.b
    assert batch2_from_list.a.c == batch2.a.c
    assert batch2_from_list.a.d.e == batch2.a.d.e
    assert batch2_from_comp.a.b == batch2.a.b
    assert batch2_from_comp.a.c == batch2.a.c
    assert batch2_from_comp.a.d.e == batch2.a.d.e
    for batch_slice in [batch2[slice(0, 1)], batch2[:1], batch2[0:]]:
        assert batch_slice.a.b == batch2.a.b
        assert batch_slice.a.c == batch2.a.c
        assert batch_slice.a.d.e == batch2.a.d.e
    batch2.a.d.f = {}
    batch2_sum = (batch2 + 1.0) * 2
    assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
    assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
    assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
    assert batch2_sum.a.d.f.is_empty()
    with pytest.raises(TypeError):
        batch2 += [1]
    batch3 = Batch(a={
        'c': np.zeros(1),
        'd': Batch(e=np.array([0.0]), f=np.array([3.0]))})
    batch3.a.d[0] = {'e': 4.0}
    assert batch3.a.d.e[0] == 4.0
    batch3.a.d[0] = Batch(f=5.0)
    assert batch3.a.d.f[0] == 5.0
    with pytest.raises(ValueError):
        batch3.a.d[0] = Batch(f=5.0, g=0.0)
    with pytest.raises(ValueError):
        batch3[0] = Batch(a={"c": 2, "e": 1})
    # auto convert
    batch4 = Batch(a=np.array(['a', 'b']))
    assert batch4.a.dtype == object  # auto convert to object
    batch4.update(a=np.array(['c', 'd']))
    assert list(batch4.a) == ['c', 'd']
    assert batch4.a.dtype == object  # auto convert to object
    batch5 = Batch(a=np.array([{'index': 0}]))
    assert isinstance(batch5.a, Batch)
    assert np.allclose(batch5.a.index, [0])
    batch5.b = np.array([{'index': 1}])
    assert isinstance(batch5.b, Batch)
    assert np.allclose(batch5.b.index, [1])
    # None is a valid object and can be stored in Batch
    a = Batch.stack([Batch(a=None), Batch(b=None)])
    assert a.a[0] is None and a.a[1] is None
    assert a.b[0] is None and a.b[1] is None
    # nx.Graph corner case
    assert Batch(
        a=np.array([nx.Graph(), nx.Graph()], dtype=object)).a.dtype == object
    g1 = nx.Graph()
    g1.add_nodes_from(list(range(10)))
    g2 = nx.Graph()
    g2.add_nodes_from(list(range(20)))
    assert Batch(a=np.array([g1, g2])).a.dtype == object
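# A small standalone sketch of the split() semantics checked in the table
# above: with merge_last=True the final short chunk is merged into the
# previous one instead of being yielded separately. Only Batch and numpy are
# assumed.
import numpy as np
from tianshou.data import Batch

batch = Batch(a=np.arange(10))
sizes = [b.a.size for b in batch.split(3, shuffle=False, merge_last=False)]
assert sizes == [3, 3, 3, 1]
sizes = [b.a.size for b in batch.split(3, shuffle=False, merge_last=True)]
assert sizes == [3, 3, 4]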
class Collector(object):
    """The :class:`~tianshou.data.Collector` enables the policy to interact
    with different types of environments conveniently.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param env: a ``gym.Env`` environment or an instance of the
        :class:`~tianshou.env.BaseVectorEnv` class.
    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer`
        class, or a list of :class:`~tianshou.data.ReplayBuffer`. If set to
        ``None``, it will automatically assign a small-size
        :class:`~tianshou.data.ReplayBuffer`.
    :param function preprocess_fn: a function called before the data has been
        added to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults
        to ``None``.
    :param int stat_size: for the moving average of recording speed, defaults
        to 100.
    :param BaseNoise action_noise: add noise to continuous actions. Normally
        a policy already has a noise param for exploration in the training
        phase, so this is recommended for use in the test collector.
    :param function reward_metric: to be used in multi-agent RL. The reward to
        report is of shape [agent_num], but we need to return a single scalar
        to monitor training. This function specifies the desired metric,
        e.g., the reward of agent 1 or the average reward over all agents. By
        default, the behavior is to select the reward of agent 1.

    The ``preprocess_fn`` is a function called before the data has been added
    to the buffer with batch format, which receives up to 7 keys as listed in
    :class:`~tianshou.data.Batch`. It will receive only ``obs`` when the
    collector resets the environment. It returns either a dict or a
    :class:`~tianshou.data.Batch` with the modified keys and values. Examples
    are in "test/base/test_collector.py".

    Example:
    ::

        policy = PGPolicy(...)  # or other policies if you wish
        env = gym.make('CartPole-v0')
        replay_buffer = ReplayBuffer(size=10000)

        # here we set up a collector with a single environment
        collector = Collector(policy, env, buffer=replay_buffer)

        # the collector supports vectorized environments as well
        envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
        buffers = [ReplayBuffer(size=5000) for _ in range(3)]
        # you can also pass a list of replay buffer to collector, for multi-env
        # collector = Collector(policy, envs, buffer=buffers)
        collector = Collector(policy, envs, buffer=replay_buffer)

        # collect at least 3 episodes
        collector.collect(n_episode=3)
        # collect 1 episode for the first env, 3 for the third env
        collector.collect(n_episode=[1, 0, 3])
        # collect at least 2 steps
        collector.collect(n_step=2)
        # collect episodes with visual rendering (the render argument is the
        # sleep time between rendering consecutive frames)
        collector.collect(n_episode=1, render=0.03)

        # sample data with a given number of batch-size:
        batch_data = collector.sample(batch_size=64)
        # policy.learn(batch_data)  # btw, vanilla policy gradient only
        # supports on-policy training, so here we pick all data in the buffer
        batch_data = collector.sample(batch_size=0)
        policy.learn(batch_data)
        # on-policy algorithms use the collected data only once, so here we
        # clear the buffer
        collector.reset_buffer()

    For the scenario of collecting data from multiple environments into a
    single buffer, the cache buffers are turned on automatically. The
    collector may return more data than the given limitation.

    .. note::

        Please make sure the given environment has a time limitation.
    """

    def __init__(
        self,
        policy: BasePolicy,
        env: Union[gym.Env, BaseVectorEnv],
        buffer: Optional[ReplayBuffer] = None,
        preprocess_fn: Optional[Callable[[Any], Union[dict, Batch]]] = None,
        stat_size: Optional[int] = 100,
        action_noise: Optional[BaseNoise] = None,
        reward_metric: Optional[Callable[[np.ndarray], float]] = None,
    ) -> None:
        super().__init__()
        self.env = env
        self.env_num = 1
        self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
        self.buffer = buffer
        self.policy = policy
        self.preprocess_fn = preprocess_fn
        self.process_fn = policy.process_fn
        self._multi_env = isinstance(env, BaseVectorEnv)
        # need multiple cache buffers only if storing in one buffer
        self._cached_buf = []
        if self._multi_env:
            self.env_num = len(env)
            self._cached_buf = [
                ListReplayBuffer() for _ in range(self.env_num)]
        self.stat_size = stat_size
        self._action_noise = action_noise
        self._rew_metric = reward_metric or Collector._default_rew_metric
        self.reset()

    @staticmethod
    def _default_rew_metric(x):
        # this internal function is designed for single-agent RL
        # for multi-agent RL, a reward_metric must be provided
        assert np.asanyarray(x).size == 1, \
            'Please specify the reward_metric ' \
            'since the reward is not a scalar.'
        return x

    def reset(self) -> None:
        """Reset all related variables in the collector."""
        self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={},
                          obs_next={}, policy={})
        self.reset_env()
        self.reset_buffer()
        self.step_speed = MovAvg(self.stat_size)
        self.episode_speed = MovAvg(self.stat_size)
        self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
        if self._action_noise is not None:
            self._action_noise.reset()

    def reset_buffer(self) -> None:
        """Reset the main data buffer."""
        if self.buffer is not None:
            self.buffer.reset()

    def get_env_num(self) -> int:
        """Return the number of environments the collector has."""
        return self.env_num

    def reset_env(self) -> None:
        """Reset all of the environments' states and all of the cache
        buffers (if needed)."""
        obs = self.env.reset()
        if not self._multi_env:
            obs = self._make_batch(obs)
        if self.preprocess_fn:
            obs = self.preprocess_fn(obs=obs).get('obs', obs)
        self.data.obs = obs
        self.reward = 0.  # will be specified when the first data is ready
        self.length = np.zeros(self.env_num)
        for b in self._cached_buf:
            b.reset()

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
        """Reset all the seeds of the given environments."""
        return self.env.seed(seed)

    def render(self, **kwargs) -> None:
        """Render all the environments."""
        return self.env.render(**kwargs)

    def close(self) -> None:
        """Close the environments."""
        self.env.close()

    def _make_batch(self, data: Any) -> np.ndarray:
        """Return [data]."""
        if isinstance(data, np.ndarray):
            return data[None]
        else:
            return np.array([data])

    def _reset_state(self, id: Union[int, List[int]]) -> None:
        """Reset self.data.state[id]."""
        state = self.data.state  # it is a reference
        if isinstance(state, torch.Tensor):
            state[id].zero_()
        elif isinstance(state, np.ndarray):
            state[id] = None if state.dtype == np.object else 0
        elif isinstance(state, Batch):
            state.empty_(id)

    def collect(
        self,
        n_step: int = 0,
        n_episode: Union[int, List[int]] = 0,
        random: bool = False,
        render: Optional[float] = None,
        log_fn: Optional[Callable[[dict], None]] = None,
    ) -> Dict[str, float]:
        """Collect a specified number of steps or episodes.

        :param int n_step: how many steps you want to collect.
        :param n_episode: how many episodes you want to collect (in each
            environment).
        :type n_episode: int or list
        :param bool random: whether to use random policy for collecting data,
            defaults to ``False``.
        :param float render: the sleep time between rendering consecutive
            frames, defaults to ``None`` (no rendering).
        :param function log_fn: a function which receives env info, typically
            for tensorboard logging.

        .. note::

            One and only one collection number specification is permitted,
            either ``n_step`` or ``n_episode``.

        :return: A dict including the following keys

            * ``n/ep`` the collected number of episodes.
            * ``n/st`` the collected number of steps.
            * ``v/st`` the speed of steps per second.
            * ``v/ep`` the speed of episodes per second.
            * ``rew`` the mean reward over collected episodes.
            * ``len`` the mean length over collected episodes.
        """
        if not self._multi_env:
            n_episode = np.sum(n_episode)
        start_time = time.time()
        assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
            "One and only one collection number specification is permitted!"
        cur_step, cur_episode = 0, np.zeros(self.env_num)
        reward_sum, length_sum = 0., 0
        # change
        ty1_succ_rate_1 = 0.
        ty1_succ_rate_2 = 0.
        ty1_succ_rate_3 = 0.
        ty1_succ_rate_4 = 0.
        Q_len_1 = 0.
        Q_len_2 = 0.
        Q_len_3 = 0.
        Q_len_4 = 0.
        energy_effi_1 = 0.
        energy_effi_2 = 0.
        energy_effi_3 = 0.
        energy_effi_4 = 0.
        avg_rate = 0.
        avg_power = 0.
        while True:
            if cur_step >= 100000 and cur_episode.sum() == 0:
                warnings.warn(
                    'There are already many steps in an episode. '
                    'You should add a time limitation to your environment!',
                    Warning)

            # restore the state and the input data
            last_state = self.data.state
            if last_state.is_empty():
                last_state = None
            self.data.update(state=Batch(), obs_next=Batch(), policy=Batch())

            # calculate the next action
            if random:
                action_space = self.env.action_space
                if isinstance(action_space, list):
                    result = Batch(act=[a.sample() for a in action_space])
                else:
                    result = Batch(
                        act=self._make_batch(action_space.sample()))
            else:
                with torch.no_grad():
                    result = self.policy(self.data, last_state)

            # convert None to Batch(), since None is reserved for 0-init
            state = result.get('state', Batch())
            if state is None:
                state = Batch()
            self.data.state = state
            if hasattr(result, 'policy'):
                self.data.policy = to_numpy(result.policy)
            # save hidden state to policy._state, in order to save into buffer
            self.data.policy._state = self.data.state

            self.data.act = to_numpy(result.act)
            if self._action_noise is not None:
                self.data.act += self._action_noise(self.data.act.shape)

            # step in env
            obs_next, rew, done, info = self.env.step(
                self.data.act if self._multi_env else self.data.act[0])

            # move data to self.data
            if not self._multi_env:
                obs_next = self._make_batch(obs_next)
                rew = self._make_batch(rew)
                done = self._make_batch(done)
                info = self._make_batch(info)
            self.data.obs_next = obs_next
            self.data.rew = rew
            self.data.done = done
            self.data.info = info

            if log_fn:
                log_fn(info if self._multi_env else info[0])
            if render:
                self.render()
                if render > 0:
                    time.sleep(render)

            # add data into the buffer
            self.length += 1
            self.reward += self.data.rew
            if self.preprocess_fn:
                result = self.preprocess_fn(**self.data)
                self.data.update(result)
            if self._multi_env:  # cache_buffer branch
                # change
                if self.data.done[0]:
                    ty1_succ_rate_1 += self.data.info[0]['ty1_succ_rate_1']
                    ty1_succ_rate_2 += self.data.info[0]['ty1_succ_rate_2']
                    ty1_succ_rate_3 += self.data.info[0]['ty1_succ_rate_3']
                    ty1_succ_rate_4 += self.data.info[0]['ty1_succ_rate_4']
                    Q_len_1 += self.data.info[0]['Q_len_1']
                    Q_len_2 += self.data.info[0]['Q_len_2']
                    Q_len_3 += self.data.info[0]['Q_len_3']
                    Q_len_4 += self.data.info[0]['Q_len_4']
                    energy_effi_1 += self.data.info[0]['energy_effi_1']
                    energy_effi_2 += self.data.info[0]['energy_effi_2']
                    energy_effi_3 += self.data.info[0]['energy_effi_3']
                    energy_effi_4 += self.data.info[0]['energy_effi_4']
                    avg_rate += self.data.info[0]['avg_rate']
                    avg_power += self.data.info[0]['avg_power']
                for i in range(self.env_num):
                    self._cached_buf[i].add(**self.data[i])
                    if self.data.done[i]:
                        if n_step != 0 or np.isscalar(n_episode) or \
                                cur_episode[i] < n_episode[i]:
                            cur_episode[i] += 1
                            reward_sum += self.reward[i]
                            length_sum += self.length[i]
                            if self._cached_buf:
                                cur_step += len(self._cached_buf[i])
                                if self.buffer is not None:
                                    self.buffer.update(self._cached_buf[i])
                        self.reward[i], self.length[i] = 0., 0
                        if self._cached_buf:
                            self._cached_buf[i].reset()
                        self._reset_state(i)
                obs_next = self.data.obs_next
                if sum(self.data.done):
                    env_ind = np.where(self.data.done)[0]
                    obs_reset = self.env.reset(env_ind)
                    if self.preprocess_fn:
                        obs_next[env_ind] = self.preprocess_fn(
                            obs=obs_reset).get('obs', obs_reset)
                    else:
                        obs_next[env_ind] = obs_reset
                self.data.obs_next = obs_next
                if n_episode != 0:
                    if isinstance(n_episode, list) and \
                            (cur_episode >= np.array(n_episode)).all() or \
                            np.isscalar(n_episode) and \
                            cur_episode.sum() >= n_episode:
                        break
            else:  # single buffer, without cache_buffer
                if self.buffer is not None:
                    self.buffer.add(**self.data[0])
                cur_step += 1
                if self.data.done[0]:
                    # change
                    ty1_succ_rate_1 += self.data.info['ty1_succ_rate_1']
                    ty1_succ_rate_2 += self.data.info['ty1_succ_rate_2']
                    ty1_succ_rate_3 += self.data.info['ty1_succ_rate_3']
                    ty1_succ_rate_4 += self.data.info['ty1_succ_rate_4']
                    Q_len_1 += self.data.info['Q_len_1']
                    Q_len_2 += self.data.info['Q_len_2']
                    Q_len_3 += self.data.info['Q_len_3']
                    Q_len_4 += self.data.info['Q_len_4']
                    energy_effi_1 += self.data.info['energy_effi_1']
                    energy_effi_2 += self.data.info['energy_effi_2']
                    energy_effi_3 += self.data.info['energy_effi_3']
                    energy_effi_4 += self.data.info['energy_effi_4']
                    avg_rate += self.data.info[0]['avg_rate']
                    avg_power += self.data.info[0]['avg_power']
                    cur_episode += 1
                    reward_sum += self.reward[0]
                    length_sum += self.length[0]
                    self.reward, self.length = 0., np.zeros(self.env_num)
                    self.data.state = Batch()
                    obs_next = self._make_batch(self.env.reset())
                    if self.preprocess_fn:
                        obs_next = self.preprocess_fn(obs=obs_next).get(
                            'obs', obs_next)
                    self.data.obs_next = obs_next
                    if n_episode != 0 and cur_episode >= n_episode:
                        break
            if n_step != 0 and cur_step >= n_step:
                break
            self.data.obs = self.data.obs_next
        self.data.obs = self.data.obs_next

        # generate the statistics
        cur_episode = sum(cur_episode)
        duration = max(time.time() - start_time, 1e-9)
        self.step_speed.add(cur_step / duration)
        self.episode_speed.add(cur_episode / duration)
        self.collect_step += cur_step
        self.collect_episode += cur_episode
        self.collect_time += duration
        if isinstance(n_episode, list):
            n_episode = np.sum(n_episode)
        else:
            n_episode = max(cur_episode, 1)
        reward_sum /= n_episode
        if np.asanyarray(reward_sum).size > 1:  # non-scalar reward_sum
            reward_sum = self._rew_metric(reward_sum)
        # change
        return {
            'n/ep': cur_episode,
            'n/st': cur_step,
            'v/st': self.step_speed.get(),
            'v/ep': self.episode_speed.get(),
            'rew': reward_sum,
            'len': length_sum / n_episode,
            'ty1s_1': ty1_succ_rate_1,
            'ty1s_2': ty1_succ_rate_2,
            'ty1s_3': ty1_succ_rate_3,
            'ty1s_4': ty1_succ_rate_4,
            'ql_1': Q_len_1,
            'ql_2': Q_len_2,
            'ql_3': Q_len_3,
            'ql_4': Q_len_4,
            'ee_1': energy_effi_1,
            'ee_2': energy_effi_2,
            'ee_3': energy_effi_3,
            'ee_4': energy_effi_4,
            'avg_r': avg_rate,
            'avg_p': avg_power,
        }

    def sample(self, batch_size: int) -> Batch:
        """Sample a data batch from the internal replay buffer.

        It will call :meth:`~tianshou.policy.BasePolicy.process_fn` before
        returning the final batch data.

        :param int batch_size: ``0`` means it will extract all the data from
            the buffer, otherwise it will extract the data with the given
            batch_size.
        """
        batch_data, indice = self.buffer.sample(batch_size)
        batch_data = self.process_fn(batch_data, self.buffer, indice)
        return batch_data
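# A hedged sketch of what Collector.sample does under the hood, assuming the
# 0.2.x-era ReplayBuffer API used above: draw transitions from the buffer,
# then let the policy post-process them (e.g., compute returns) via
# process_fn. batch_size=0 extracts everything currently stored.
from tianshou.data import ReplayBuffer

buf = ReplayBuffer(size=8)
for i in range(5):
    buf.add(obs=i, act=i, rew=i, done=(i == 4), obs_next=i + 1, info={})
batch_data, indice = buf.sample(0)  # batch_size=0 -> all stored transitions
assert len(batch_data) == 5
# Collector.sample would now call policy.process_fn(batch_data, buf, indice)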
class Collector(object):
    """Collector enables the policy to interact with different types of envs.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param env: a ``gym.Env`` environment or an instance of the
        :class:`~tianshou.env.BaseVectorEnv` class.
    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer`
        class. If set to ``None`` (testing phase), it will not store the data.
    :param function preprocess_fn: a function called before the data has been
        added to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults
        to None.
    :param BaseNoise action_noise: add noise to continuous actions. Normally
        a policy already has a noise param for exploration in the training
        phase, so this is recommended for use in the test collector.
    :param function reward_metric: to be used in multi-agent RL. The reward to
        report is of shape [agent_num], but we need to return a single scalar
        to monitor training. This function specifies the desired metric,
        e.g., the reward of agent 1 or the average reward over all agents. By
        default, the behavior is to select the reward of agent 1.

    The ``preprocess_fn`` is a function called before the data has been added
    to the buffer with batch format, which receives up to 7 keys as listed in
    :class:`~tianshou.data.Batch`. It will receive only ``obs`` when the
    collector resets the environment. It returns either a dict or a
    :class:`~tianshou.data.Batch` with the modified keys and values. Examples
    are in "test/base/test_collector.py".

    Here is the example:
    ::

        policy = PGPolicy(...)  # or other policies if you wish
        env = gym.make('CartPole-v0')
        replay_buffer = ReplayBuffer(size=10000)

        # here we set up a collector with a single environment
        collector = Collector(policy, env, buffer=replay_buffer)

        # the collector supports vectorized environments as well
        envs = DummyVectorEnv([lambda: gym.make('CartPole-v0')
                               for _ in range(3)])
        collector = Collector(policy, envs, buffer=replay_buffer)

        # collect 3 episodes
        collector.collect(n_episode=3)
        # collect 1 episode for the first env, 3 for the third env
        collector.collect(n_episode=[1, 0, 3])
        # collect at least 2 steps
        collector.collect(n_step=2)
        # collect episodes with visual rendering (the render argument is the
        # sleep time between rendering consecutive frames)
        collector.collect(n_episode=1, render=0.03)

    Collected data always consists of full episodes. So if only the ``n_step``
    argument is given, the collector may return data exceeding the ``n_step``
    limitation. The same holds for ``n_episode`` in the multiple-environment
    case.

    .. note::

        Please make sure the given environment has a time limitation.
    """

    def __init__(
        self,
        policy: BasePolicy,
        env: Union[gym.Env, BaseVectorEnv],
        buffer: Optional[ReplayBuffer] = None,
        preprocess_fn: Optional[Callable[..., Batch]] = None,
        action_noise: Optional[BaseNoise] = None,
        reward_metric: Optional[Callable[[np.ndarray], float]] = None,
    ) -> None:
        super().__init__()
        if not isinstance(env, BaseVectorEnv):
            env = DummyVectorEnv([lambda: env])
        self.env = env
        self.env_num = len(env)
        # environments that are available in step()
        # this means all environments in synchronous simulation
        # but only a subset of environments in asynchronous simulation
        self._ready_env_ids = np.arange(self.env_num)
        # self.async is a flag to indicate whether this collector works
        # with asynchronous simulation
        self.is_async = env.is_async
        # need cache buffers before storing in the main buffer
        self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)]
        self.buffer = buffer
        self.policy = policy
        self.preprocess_fn = preprocess_fn
        self.process_fn = policy.process_fn
        self._action_space = env.action_space
        self._action_noise = action_noise
        self._rew_metric = reward_metric or Collector._default_rew_metric
        # avoid creating attribute outside __init__
        self.reset()

    @staticmethod
    def _default_rew_metric(
        x: Union[Number, np.number]
    ) -> Union[Number, np.number]:
        # this internal function is designed for single-agent RL
        # for multi-agent RL, a reward_metric must be provided
        assert np.asanyarray(x).size == 1, (
            "Please specify the reward_metric "
            "since the reward is not a scalar."
        )
        return x

    def reset(self) -> None:
        """Reset all related variables in the collector."""
        # use empty Batch for ``state`` so that ``self.data`` supports slicing
        # convert empty Batch to None when passing data to policy
        self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={},
                          obs_next={}, policy={})
        self.reset_env()
        self.reset_buffer()
        self.reset_stat()
        if self._action_noise is not None:
            self._action_noise.reset()

    def reset_stat(self) -> None:
        """Reset the statistic variables."""
        self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0

    def reset_buffer(self) -> None:
        """Reset the main data buffer."""
        if self.buffer is not None:
            self.buffer.reset()

    def get_env_num(self) -> int:
        """Return the number of environments the collector has."""
        return self.env_num

    def reset_env(self) -> None:
        """Reset all of the environments' states and the cache buffers."""
        self._ready_env_ids = np.arange(self.env_num)
        obs = self.env.reset()
        if self.preprocess_fn:
            obs = self.preprocess_fn(obs=obs).get("obs", obs)
        self.data.obs = obs
        for b in self._cached_buf:
            b.reset()

    def _reset_state(self, id: Union[int, List[int]]) -> None:
        """Reset the hidden state: self.data.state[id]."""
        state = self.data.state  # it is a reference
        if isinstance(state, torch.Tensor):
            state[id].zero_()
        elif isinstance(state, np.ndarray):
            state[id] = None if state.dtype == np.object else 0
        elif isinstance(state, Batch):
            state.empty_(id)

    def collect(
        self,
        n_step: Optional[int] = None,
        n_episode: Optional[Union[int, List[int]]] = None,
        random: bool = False,
        render: Optional[float] = None,
        no_grad: bool = True,
    ) -> Dict[str, float]:
        """Collect a specified number of steps or episodes.

        :param int n_step: how many steps you want to collect.
        :param n_episode: how many episodes you want to collect. If it is an
            int, it means to collect at least ``n_episode`` episodes; if it
            is a list, it means to collect exactly ``n_episode[i]`` episodes
            in the i-th environment.
        :param bool random: whether to use random policy for collecting data,
            defaults to False.
        :param float render: the sleep time between rendering consecutive
            frames, defaults to None (no rendering).
        :param bool no_grad: whether to retain gradient in policy.forward,
            defaults to True (no gradient retaining).

        .. note::

            One and only one collection number specification is permitted,
            either ``n_step`` or ``n_episode``.

        :return: A dict including the following keys

            * ``n/ep`` the collected number of episodes.
            * ``n/st`` the collected number of steps.
            * ``v/st`` the speed of steps per second.
            * ``v/ep`` the speed of episodes per second.
            * ``rew`` the mean reward over collected episodes.
            * ``len`` the mean length over collected episodes.
        """
        assert (n_step is not None and n_episode is None and n_step > 0) or (
            n_step is None and n_episode is not None and np.sum(n_episode) > 0
        ), ("Only one of n_step or n_episode is allowed in Collector.collect, "
            f"got n_step = {n_step}, n_episode = {n_episode}.")
        start_time = time.time()
        step_count = 0
        # episode of each environment
        episode_count = np.zeros(self.env_num)
        # If n_episode is a list, and some envs have collected the required
        # number of episodes, these envs will be recorded in this list, and
        # they will not be stepped.
        finished_env_ids = []
        rewards = []
        whole_data = Batch()
        if isinstance(n_episode, list):
            assert len(n_episode) == self.get_env_num()
            finished_env_ids = [
                i for i in self._ready_env_ids if n_episode[i] <= 0]
            self._ready_env_ids = np.array(
                [x for x in self._ready_env_ids if x not in finished_env_ids])
        while True:
            if step_count >= 100000 and episode_count.sum() == 0:
                warnings.warn(
                    "There are already many steps in an episode. "
                    "You should add a time limitation to your environment!",
                    Warning)

            is_async = self.is_async or len(finished_env_ids) > 0
            if is_async:
                # self.data are the data for all environments in async
                # simulation or some envs have finished,
                # **only a subset of data are disposed**,
                # so we store the whole data in ``whole_data``, let self.data
                # to be the data available in ready environments, and finally
                # set these back into all the data
                whole_data = self.data
                self.data = self.data[self._ready_env_ids]

            # restore the state and the input data
            last_state = self.data.state
            if isinstance(last_state, Batch) and last_state.is_empty():
                last_state = None
            self.data.update(state=Batch(), obs_next=Batch(), policy=Batch())

            # calculate the next action
            if random:
                spaces = self._action_space
                result = Batch(
                    act=[spaces[i].sample() for i in self._ready_env_ids])
            else:
                if no_grad:
                    with torch.no_grad():  # faster than retain_grad version
                        result = self.policy(self.data, last_state)
                else:
                    result = self.policy(self.data, last_state)

            state = result.get("state", Batch())
            # convert None to Batch(), since None is reserved for 0-init
            if state is None:
                state = Batch()
            self.data.update(state=state,
                             policy=result.get("policy", Batch()))
            # save hidden state to policy._state, in order to save into buffer
            if not (isinstance(state, Batch) and state.is_empty()):
                self.data.policy._state = self.data.state

            self.data.act = to_numpy(result.act)
            if self._action_noise is not None:
                assert isinstance(self.data.act, np.ndarray)
                self.data.act += self._action_noise(self.data.act.shape)

            # step in env
            if not is_async:
                obs_next, rew, done, info = self.env.step(self.data.act)
            else:
                # store computed actions, states, etc
                _batch_set_item(
                    whole_data, self._ready_env_ids, self.data, self.env_num)
                # fetch finished data
                obs_next, rew, done, info = self.env.step(
                    self.data.act, id=self._ready_env_ids)
                self._ready_env_ids = np.array([i["env_id"] for i in info])
                # get the stepped data
                self.data = whole_data[self._ready_env_ids]
            # move data to self.data
            self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)

            if render:
                self.env.render()
                time.sleep(render)

            # add data into the buffer
            if self.preprocess_fn:
                result = self.preprocess_fn(**self.data)  # type: ignore
                self.data.update(result)

            for j, i in enumerate(self._ready_env_ids):
                # j is the index in current ready_env_ids
                # i is the index in all environments
                if self.buffer is None:
                    # users do not want to store data, so we store
                    # small fake data here to make the code clean
                    self._cached_buf[i].add(obs=0, act=0, rew=rew[j], done=0)
                else:
                    self._cached_buf[i].add(**self.data[j])

                if done[j]:
                    if not (isinstance(n_episode, list)
                            and episode_count[i] >= n_episode[i]):
                        episode_count[i] += 1
                        rewards.append(self._rew_metric(
                            np.sum(self._cached_buf[i].rew, axis=0)))
                        step_count += len(self._cached_buf[i])
                        if self.buffer is not None:
                            self.buffer.update(self._cached_buf[i])
                        if isinstance(n_episode, list) and \
                                episode_count[i] >= n_episode[i]:
                            # env i has collected enough data, it has finished
                            finished_env_ids.append(i)
                    self._cached_buf[i].reset()
                    self._reset_state(j)

            obs_next = self.data.obs_next
            if sum(done):
                env_ind_local = np.where(done)[0]
                env_ind_global = self._ready_env_ids[env_ind_local]
                obs_reset = self.env.reset(env_ind_global)
                if self.preprocess_fn:
                    obs_reset = self.preprocess_fn(
                        obs=obs_reset).get("obs", obs_reset)
                obs_next[env_ind_local] = obs_reset
            self.data.obs = obs_next

            if is_async:
                # set data back
                whole_data = deepcopy(whole_data)  # avoid reference in ListBuf
                _batch_set_item(
                    whole_data, self._ready_env_ids, self.data, self.env_num)
                # let self.data be the data in all environments again
                self.data = whole_data
                self._ready_env_ids = np.array(
                    [x for x in self._ready_env_ids
                     if x not in finished_env_ids])

            if n_step:
                if step_count >= n_step:
                    break
            else:
                if isinstance(n_episode, int) and \
                        episode_count.sum() >= n_episode:
                    break
                if isinstance(n_episode, list) and \
                        (episode_count >= n_episode).all():
                    break

        # finished envs are ready, and can be used for the next collection
        self._ready_env_ids = np.array(
            self._ready_env_ids.tolist() + finished_env_ids)

        # generate the statistics
        episode_count = sum(episode_count)
        duration = max(time.time() - start_time, 1e-9)
        self.collect_step += step_count
        self.collect_episode += episode_count
        self.collect_time += duration
        return {
            "n/ep": episode_count,
            "n/st": step_count,
            "v/st": step_count / duration,
            "v/ep": episode_count / duration,
            "rew": np.mean(rewards),
            "rew_std": np.std(rewards),
            "len": step_count / episode_count,
        }
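# A hedged usage sketch for the list-valued n_episode option documented
# above: collect exactly 1 episode from env 0, none from env 1, and 2 from
# env 2. `policy` is assumed to be an existing BasePolicy instance
# (hypothetical here), as in the class docstring example.
import gym
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import DummyVectorEnv

envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
collector = Collector(policy, envs, buffer=ReplayBuffer(size=10000))
stats = collector.collect(n_episode=[1, 0, 2])
assert stats['n/ep'] == 3  # 1 + 0 + 2 episodes in total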
def test_batch():
    assert list(Batch()) == []
    assert Batch().is_empty()
    assert not Batch(b={'c': {}}).is_empty()
    assert Batch(b={'c': {}}).is_empty(recurse=True)
    assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
    assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
    assert not Batch(d=1).is_empty()
    assert not Batch(a=np.float64(1.0)).is_empty()
    assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
    assert not Batch(a=[1, 2, 3]).is_empty()
    b = Batch()
    b.update()
    assert b.is_empty()
    b.update(c=[3, 5])
    assert np.allclose(b.c, [3, 5])
    # mimic the behavior of dict.update, where kwargs can overwrite keys
    b.update({'a': 2}, a=3)
    assert b.a == 3
    with pytest.raises(AssertionError):
        Batch({1: 2})
    with pytest.raises(TypeError):
        Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])
    batch = Batch(a=[torch.ones(3), torch.ones(3)])
    assert torch.allclose(batch.a, torch.ones(2, 3))
    Batch(a=[])
    batch = Batch(obs=[0], np=np.zeros([3, 4]))
    assert batch.obs == batch["obs"]
    batch.obs = [1]
    assert batch.obs == [1]
    batch.cat_(batch)
    assert np.allclose(batch.obs, [1, 1])
    assert batch.np.shape == (6, 4)
    assert np.allclose(batch[0].obs, batch[1].obs)
    batch.obs = np.arange(5)
    for i, b in enumerate(batch.split(1, shuffle=False)):
        if i != 5:
            assert b.obs == batch[i].obs
        else:
            with pytest.raises(AttributeError):
                batch[i].obs
            with pytest.raises(AttributeError):
                b.obs
    print(batch)
    batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])}
    batch_item = Batch({'a': [batch_dict]})[0]
    assert isinstance(batch_item.a.b, np.ndarray)
    assert batch_item.a.b == batch_dict['b']
    assert isinstance(batch_item.a.c, float)
    assert batch_item.a.c == batch_dict['c']
    assert isinstance(batch_item.a.d, torch.Tensor)
    assert batch_item.a.d == batch_dict['d']
    batch2 = Batch(a=[{
        'b': np.float64(1.0),
        'c': np.zeros(1),
        'd': Batch(e=np.array(3.0))
    }])
    assert len(batch2) == 1
    assert Batch().shape == []
    assert Batch(a=1).shape == []
    assert batch2.shape[0] == 1
    with pytest.raises(IndexError):
        batch2[-2]
    with pytest.raises(IndexError):
        batch2[1]
    assert batch2[0].shape == []
    with pytest.raises(IndexError):
        batch2[0][0]
    with pytest.raises(TypeError):
        len(batch2[0])
    assert isinstance(batch2[0].a.c, np.ndarray)
    assert isinstance(batch2[0].a.b, np.float64)
    assert isinstance(batch2[0].a.d.e, np.float64)
    batch2_from_list = Batch(list(batch2))
    batch2_from_comp = Batch([e for e in batch2])
    assert batch2_from_list.a.b == batch2.a.b
    assert batch2_from_list.a.c == batch2.a.c
    assert batch2_from_list.a.d.e == batch2.a.d.e
    assert batch2_from_comp.a.b == batch2.a.b
    assert batch2_from_comp.a.c == batch2.a.c
    assert batch2_from_comp.a.d.e == batch2.a.d.e
    for batch_slice in [batch2[slice(0, 1)], batch2[:1], batch2[0:]]:
        assert batch_slice.a.b == batch2.a.b
        assert batch_slice.a.c == batch2.a.c
        assert batch_slice.a.d.e == batch2.a.d.e
    batch2_sum = (batch2 + 1.0) * 2
    assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
    assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
    assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
    batch3 = Batch(a={
        'c': np.zeros(1),
        'd': Batch(e=np.array([0.0]), f=np.array([3.0]))
    })
    batch3.a.d[0] = {'e': 4.0}
    assert batch3.a.d.e[0] == 4.0
    batch3.a.d[0] = Batch(f=5.0)
    assert batch3.a.d.f[0] == 5.0
    with pytest.raises(KeyError):
        batch3.a.d[0] = Batch(f=5.0, g=0.0)
    # auto convert
    batch4 = Batch(a=np.array(['a', 'b']))
    assert batch4.a.dtype == np.object  # auto convert to np.object
    batch4.update(a=np.array(['c', 'd']))
    assert list(batch4.a) == ['c', 'd']
    assert batch4.a.dtype == np.object  # auto convert to np.object
    batch5 = Batch(a=np.array([{'index': 0}]))
    assert isinstance(batch5.a, Batch)
    assert np.allclose(batch5.a.index, [0])
    batch5.b = np.array([{'index': 1}])
    assert isinstance(batch5.b, Batch)
    assert np.allclose(batch5.b.index, [1])
    # None is a valid object and can be stored in Batch
    a = Batch.stack([Batch(a=None), Batch(b=None)])
    assert a.a[0] is None and a.a[1] is None
    assert a.b[0] is None and a.b[1] is None
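# A minimal sketch of Batch.stack key aggregation, complementing the
# None-stacking case tested above (reusing this test file's imports). The
# zero-padding of keys missing from one of the stacked batches matches the
# behavior documented for tianshou's Batch; treat the exact padding rule as
# an assumption for older versions.
def test_batch_stack_padding_sketch():
    s = Batch.stack([Batch(a=np.ones(2)), Batch(b=np.ones(2))])
    assert np.allclose(s.a, [[1, 1], [0, 0]])  # 'a' missing in 2nd batch
    assert np.allclose(s.b, [[0, 0], [1, 1]])  # 'b' missing in 1st batch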
class Collector(object): """The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param env: a ``gym.Env`` environment or an instance of the :class:`~tianshou.env.BaseVectorEnv` class. :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. If set to ``None`` (testing phase), it will not store the data. :param function preprocess_fn: a function called before the data has been added to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults to ``None``. :param BaseNoise action_noise: add a noise to continuous action. Normally a policy already has a noise param for exploration in training phase, so this is recommended to use in test collector for some purpose. :param function reward_metric: to be used in multi-agent RL. The reward to report is of shape [agent_num], but we need to return a single scalar to monitor training. This function specifies what is the desired metric, e.g., the reward of agent 1 or the average reward over all agents. By default, the behavior is to select the reward of agent 1. The ``preprocess_fn`` is a function called before the data has been added to the buffer with batch format, which receives up to 7 keys as listed in :class:`~tianshou.data.Batch`. It will receive with only ``obs`` when the collector resets the environment. It returns either a dict or a :class:`~tianshou.data.Batch` with the modified keys and values. Examples are in "test/base/test_collector.py". Example: :: policy = PGPolicy(...) # or other policies if you wish env = gym.make('CartPole-v0') replay_buffer = ReplayBuffer(size=10000) # here we set up a collector with a single environment collector = Collector(policy, env, buffer=replay_buffer) # the collector supports vectorized environments as well envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)]) collector = Collector(policy, envs, buffer=replay_buffer) # collect 3 episodes collector.collect(n_episode=3) # collect 1 episode for the first env, 3 for the third env collector.collect(n_episode=[1, 0, 3]) # collect at least 2 steps collector.collect(n_step=2) # collect episodes with visual rendering (the render argument is the # sleep time between rendering consecutive frames) collector.collect(n_episode=1, render=0.03) Collected data always consist of full episodes. So if only ``n_step`` argument is give, the collector may return the data more than the ``n_step`` limitation. Same as ``n_episode`` for the multiple environment case. .. note:: Please make sure the given environment has a time limitation. 
""" def __init__( self, policy: BasePolicy, env: Union[gym.Env, BaseVectorEnv], buffer: Optional[ReplayBuffer] = None, preprocess_fn: Callable[[Any], Union[dict, Batch]] = None, action_noise: Optional[BaseNoise] = None, reward_metric: Optional[Callable[[np.ndarray], float]] = None, ) -> None: super().__init__() if not isinstance(env, BaseVectorEnv): env = DummyVectorEnv([lambda: env]) self.env = env self.env_num = len(env) # environments that are available in step() # this means all environments in synchronous simulation # but only a subset of environments in asynchronous simulation self._ready_env_ids = np.arange(self.env_num) # self.async is a flag to indicate whether this collector works # with asynchronous simulation self.is_async = env.is_async # need cache buffers before storing in the main buffer self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)] self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0 self.buffer = buffer self.policy = policy self.preprocess_fn = preprocess_fn self.process_fn = policy.process_fn self._action_space = env.action_space self._action_noise = action_noise self._rew_metric = reward_metric or Collector._default_rew_metric # avoid creating attribute outside __init__ self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={}, obs_next={}, policy={}) self.reset() @staticmethod def _default_rew_metric(x): # this internal function is designed for single-agent RL # for multi-agent RL, a reward_metric must be provided assert np.asanyarray(x).size == 1, \ 'Please specify the reward_metric ' \ 'since the reward is not a scalar.' return x def reset(self) -> None: """Reset all related variables in the collector.""" # use empty Batch for ``state`` so that ``self.data`` supports slicing # convert empty Batch to None when passing data to policy self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={}, obs_next={}, policy={}) self.reset_env() self.reset_buffer() self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0 if self._action_noise is not None: self._action_noise.reset() def reset_buffer(self) -> None: """Reset the main data buffer.""" if self.buffer is not None: self.buffer.reset() def get_env_num(self) -> int: """Return the number of environments the collector have.""" return self.env_num def reset_env(self) -> None: """Reset all of the environment(s)' states and reset all of the cache buffers (if need). """ self._ready_env_ids = np.arange(self.env_num) obs = self.env.reset() if self.preprocess_fn: obs = self.preprocess_fn(obs=obs).get('obs', obs) self.data.obs = obs for b in self._cached_buf: b.reset() def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None: """Reset all the seed(s) of the given environment(s).""" return self.env.seed(seed) def render(self, **kwargs) -> None: """Render all the environment(s).""" return self.env.render(**kwargs) def _reset_state(self, id: Union[int, List[int]]) -> None: """Reset the hidden state: self.data.state[id].""" state = self.data.state # it is a reference if isinstance(state, torch.Tensor): state[id].zero_() elif isinstance(state, np.ndarray): state[id] = None if state.dtype == np.object else 0 elif isinstance(state, Batch): state.empty_(id) def collect( self, n_step: Optional[int] = None, n_episode: Optional[Union[int, List[int]]] = None, random: bool = False, render: Optional[float] = None, ) -> Dict[str, float]: """Collect a specified number of step or episode. :param int n_step: how many steps you want to collect. 
        :param n_episode: how many episodes you want to collect. If it is an
            int, it means to collect at least ``n_episode`` episodes; if it is
            a list, it means to collect exactly ``n_episode[i]`` episodes in
            the i-th environment.
        :param bool random: whether to use random policy for collecting data,
            defaults to ``False``.
        :param float render: the sleep time between rendering consecutive
            frames, defaults to ``None`` (no rendering).

        .. note::

            One and only one collection number specification is permitted,
            either ``n_step`` or ``n_episode``.

        :return: A dict including the following keys

            * ``n/ep`` the collected number of episodes.
            * ``n/st`` the collected number of steps.
            * ``v/st`` the speed of steps per second.
            * ``v/ep`` the speed of episodes per second.
            * ``rew`` the mean reward over collected episodes.
            * ``len`` the mean length over collected episodes.
        """
        assert (n_step and not n_episode) or (not n_step and n_episode), \
            "One and only one collection number specification is permitted!"
        start_time = time.time()
        step_count = 0
        # episode count of each environment
        episode_count = np.zeros(self.env_num)
        reward_total = 0.0
        whole_data = Batch()
        while True:
            if step_count >= 100000 and episode_count.sum() == 0:
                warnings.warn(
                    'There are already many steps in an episode. '
                    'You should add a time limitation to your environment!',
                    Warning)
            if self.is_async:
                # self.data contains the data for all environments; in async
                # simulation only a subset of data is disposed at a time, so
                # we store the whole data in ``whole_data``, let self.data be
                # the data available in ready environments, and finally set
                # these back into the whole data
                whole_data = self.data
                self.data = self.data[self._ready_env_ids]
            # restore the state and the input data
            last_state = self.data.state
            if isinstance(last_state, Batch) and last_state.is_empty():
                last_state = None
            self.data.update(state=Batch(), obs_next=Batch(), policy=Batch())
            # calculate the next action
            if random:
                spaces = self._action_space
                result = Batch(
                    act=[spaces[i].sample() for i in self._ready_env_ids])
            else:
                with torch.no_grad():
                    result = self.policy(self.data, last_state)
            state = result.get('state', Batch())
            # convert None to Batch(), since None is reserved for 0-init
            if state is None:
                state = Batch()
            self.data.update(state=state,
                             policy=result.get('policy', Batch()))
            # save hidden state to policy._state, to save it into the buffer
            if not (isinstance(self.data.state, Batch)
                    and self.data.state.is_empty()):
                self.data.policy._state = self.data.state
            self.data.act = to_numpy(result.act)
            if self._action_noise is not None:
                self.data.act += self._action_noise(self.data.act.shape)
            # step in env
            if not self.is_async:
                obs_next, rew, done, info = self.env.step(self.data.act)
            else:
                # store computed actions, states, etc.
                _batch_set_item(whole_data, self._ready_env_ids,
                                self.data, self.env_num)
                # fetch finished data
                obs_next, rew, done, info = self.env.step(
                    action=self.data.act, id=self._ready_env_ids)
                self._ready_env_ids = np.array([i['env_id'] for i in info])
                # get the stepped data
                self.data = whole_data[self._ready_env_ids]
            # move data to self.data
            self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
            if render:
                self.render()
                time.sleep(render)
            # add data into the buffer
            if self.preprocess_fn:
                result = self.preprocess_fn(**self.data)
                self.data.update(result)
            for j, i in enumerate(self._ready_env_ids):
                # j is the index in current ready_env_ids
                # i is the index in all environments
                self._cached_buf[i].add(**self.data[j])
                if self.data.done[j]:
                    if n_step or np.isscalar(n_episode) or \
                            episode_count[i] < n_episode[i]:
                        episode_count[i] += 1
                        reward_total += np.sum(self._cached_buf[i].rew,
                                               axis=0)
                        step_count += len(self._cached_buf[i])
                        if self.buffer is not None:
                            self.buffer.update(self._cached_buf[i])
                    self._cached_buf[i].reset()
                    self._reset_state(j)
            obs_next = self.data.obs_next
            if sum(self.data.done):
                env_ind_local = np.where(self.data.done)[0]
                env_ind_global = self._ready_env_ids[env_ind_local]
                obs_reset = self.env.reset(env_ind_global)
                if self.preprocess_fn:
                    obs_next[env_ind_local] = self.preprocess_fn(
                        obs=obs_reset).get('obs', obs_reset)
                else:
                    obs_next[env_ind_local] = obs_reset
            self.data.obs = obs_next
            if self.is_async:
                # set data back
                _batch_set_item(whole_data, self._ready_env_ids,
                                self.data, self.env_num)
                # let self.data be the data in all environments again
                self.data = whole_data
            if n_step:
                if step_count >= n_step:
                    break
            else:
                if isinstance(n_episode, int) and \
                        episode_count.sum() >= n_episode:
                    break
                if isinstance(n_episode, list) and \
                        (episode_count >= n_episode).all():
                    break
        # generate the statistics
        episode_count = sum(episode_count)
        duration = max(time.time() - start_time, 1e-9)
        self.collect_step += step_count
        self.collect_episode += episode_count
        self.collect_time += duration
        # average reward across the number of episodes
        reward_avg = reward_total / episode_count
        if np.asanyarray(reward_avg).size > 1:  # non-scalar reward_avg
            reward_avg = self._rew_metric(reward_avg)
        return {
            'n/ep': episode_count,
            'n/st': step_count,
            'v/st': step_count / duration,
            'v/ep': episode_count / duration,
            'rew': reward_avg,
            'len': step_count / episode_count,
        }

    def sample(self, batch_size: int) -> Batch:
        """Sample a data batch from the internal replay buffer.

        It will call :meth:`~tianshou.policy.BasePolicy.process_fn` before
        returning the final batch data.

        :param int batch_size: ``0`` means it will extract all the data from
            the buffer, otherwise it will extract the data with the given
            batch_size.
        """
        warnings.warn(
            'Collector.sample is deprecated and will cause error if you use '
            'prioritized experience replay! Collector.sample will be removed '
            'upon version 0.3. Use policy.update instead!', Warning)
        batch_data, indice = self.buffer.sample(batch_size)
        batch_data = self.process_fn(batch_data, self.buffer, indice)
        return batch_data

    def close(self) -> None:
        warnings.warn(
            'Collector.close is deprecated and will be removed upon version '
            '0.3.', Warning)
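
# Usage sketch (illustrative, not part of the library): ``Collector.sample``
# above is deprecated in favor of ``BasePolicy.update``, which samples from
# the buffer and runs ``process_fn`` plus the learning step itself. The batch
# size of 64 below is an arbitrary example value.
def _example_train_step(policy: BasePolicy, collector: Collector) -> dict:
    collector.collect(n_step=64)  # gather fresh transitions into the buffer
    return policy.update(64, collector.buffer)  # replaces sample + learn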
    def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                   indice: np.ndarray) -> Batch:
        batch = super().process_fn(batch, buffer, indice)
        step = batch.step
        done_cnt = batch.done_cnt
        rel_step = step / np.max(step) if step.any() else np.zeros_like(step)
        if self.bk_step:
            # convert backward step to forward step
            rel_step = 1 - rel_step
        if self.reweigh_type == "hard":
            med = np.median(step)
            cond = step > med if self.bk_step else step < med
            weight = np.where(cond, self.tper_weight, 2 - self.tper_weight)
        elif self.reweigh_type == "linear":
            weight = self._calc_linear_weight(rel_step, self.l, self.h,
                                              self.k, self.b)
        elif self.reweigh_type == 'adaptive_linear':
            cur_low = np.clip(
                self.low_l + (self.low_h - self.low_l)
                / (self.t_e - self.t_s) * (self._iter - self.t_s),
                self.low_l, self.low_h)
            cur_high = np.clip(
                self.high_h + (self.high_l - self.high_h)
                / (self.t_e - self.t_s) * (self._iter - self.t_s),
                self.high_l, self.high_h)
            weight = self._calc_linear_weight(rel_step, cur_low, cur_high,
                                              self.k, self.b)
        elif self.reweigh_type == 'done_cnt_linear':
            rel_done_cnt = done_cnt / np.max(done_cnt)
            # the trajectory is newer when the done count is larger, which
            # can be understood as fewer learning steps
            pseudo_step = 1 - rel_done_cnt
            cur_low = np.clip(
                self.low_l + (self.low_h - self.low_l) * pseudo_step,
                self.low_l, self.low_h)
            cur_high = np.clip(
                self.high_h + (self.high_l - self.high_h) * pseudo_step,
                self.high_l, self.high_h)
            weight = self._calc_linear_weight(rel_step, cur_low, cur_high,
                                              self.k, self.b)
        elif self.reweigh_type == 'oracle':
            info = batch.info
            agent_pos = info["agent_pos"]
            reward = batch.rew
            done = batch.done
            action = batch.act
            next_agent_pos = self._get_next_agent_pos(agent_pos, action)
            next_V = self._get_oracle_V(next_agent_pos)
            Qstar = reward + (1 - done) * next_V
            with torch.no_grad():
                Qs = self.forward(batch).logits.detach().cpu().numpy()
            # Q-value of the action actually taken in each transition
            Qk = np.array([Q[a] for Q, a in zip(Qs, action)])
            weight = np.exp(-np.abs(Qk - Qstar))
        assert weight.shape[0] == rel_step.shape[0]
        # normalize the weights to have mean 1 over the batch
        weight = weight / np.sum(weight) * rel_step.shape[0]
        batch.update({"weight": weight})
        return batch
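
# Hypothetical sketch of the ``_calc_linear_weight`` helper called above; the
# real implementation lives elsewhere in this class. The assumption here is
# that it maps the relative step through an affine function k * rel_step + b
# and clips the result into [low, high] (low/high may be scalars or arrays).
def _calc_linear_weight_sketch(rel_step: np.ndarray, low, high,
                               k: float, b: float) -> np.ndarray:
    return np.clip(k * rel_step + b, low, high)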
class adversarial_training_collector(object):
    """Collector that defends an existing policy with adversarial training.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param env: a ``gym.Env`` environment or an instance of the
        :class:`~tianshou.env.BaseVectorEnv` class.
    :param obs_adv_atk: an instance of the
        :class:`~advertorch.attacks.base.Attack` class implementing an image
        adversarial attack.
    :param atk_frequency: float, the fraction of steps on which env
        observations are attacked.
    :param test: bool, if True adversarial actions replace the original
        actions.
    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer`
        class. If set to ``None`` (testing phase), it will not store the data.
    :param function preprocess_fn: a function called before the data has been
        added to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults
        to None.
    :param function reward_metric: to be used in multi-agent RL. The reward to
        report is of shape [agent_num], but we need a single scalar to monitor
        training. This function specifies the desired metric, e.g., the reward
        of agent 1 or the average reward over all agents. By default, the
        behavior is to select the reward of agent 1.

    Note: parallel or async envs are currently not supported.
    """

    def __init__(
        self,
        policy: BasePolicy,
        env: Union[gym.Env, BaseVectorEnv],
        obs_adv_atk: Attack,
        buffer: Optional[ReplayBuffer] = None,
        preprocess_fn: Optional[Callable[..., Batch]] = None,
        reward_metric: Optional[Callable[[np.ndarray], float]] = None,
        atk_frequency: float = 0.5,
        test: bool = False,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    ) -> None:
        super().__init__()
        if not isinstance(env, BaseVectorEnv):
            env = DummyVectorEnv([lambda: env])
        self.env = env
        self.env_num = len(env)
        self.device = device
        self.obs_adv_atk = obs_adv_atk
        self.obs_adv_atk.targeted = False
        self.atk_frequency = atk_frequency
        self.test = test
        # environments that are available in step()
        # this means all environments in synchronous simulation
        # but only a subset of environments in asynchronous simulation
        self._ready_env_ids = np.arange(self.env_num)
        # need cache buffers before storing in the main buffer
        self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)]
        self.buffer = buffer
        self.policy = policy
        self.preprocess_fn = preprocess_fn
        self.process_fn = policy.process_fn
        self._action_space = env.action_space
        self._rew_metric = (
            reward_metric or adversarial_training_collector._default_rew_metric)
        # avoid creating attribute outside __init__
        self.reset()

    @staticmethod
    def _default_rew_metric(
            x: Union[Number, np.number]) -> Union[Number, np.number]:
        # this internal function is designed for single-agent RL
        # for multi-agent RL, a reward_metric must be provided
        assert np.asanyarray(x).size == 1, (
            "Please specify the reward_metric "
            "since the reward is not a scalar.")
        return x

    def reset(self) -> None:
        """Reset all related variables in the collector."""
        # use empty Batch for ``state`` so that ``self.data`` supports slicing
        # convert empty Batch to None when passing data to policy
        self.data = Batch(state={}, obs={}, act={}, rew={}, done={},
                          info={}, obs_next={}, policy={})
        self.reset_env()
        self.reset_buffer()
        self.reset_stat()

    def reset_stat(self) -> None:
        """Reset the statistic variables."""
        self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0

    def reset_buffer(self) -> None:
        """Reset the main data buffer."""
        if self.buffer is not None:
            self.buffer.reset()
    def get_env_num(self) -> int:
        """Return the number of environments the collector has."""
        return self.env_num

    def reset_env(self) -> None:
        """Reset all of the environments' states and the cache buffers."""
        self._ready_env_ids = np.arange(self.env_num)
        obs = self.env.reset()
        if self.preprocess_fn:
            obs = self.preprocess_fn(obs=obs).get("obs", obs)
        self.data.obs = obs
        for b in self._cached_buf:
            b.reset()

    def _reset_state(self, id: Union[int, List[int]]) -> None:
        """Reset the hidden state: self.data.state[id]."""
        state = self.data.state  # it is a reference
        if isinstance(state, torch.Tensor):
            state[id].zero_()
        elif isinstance(state, np.ndarray):
            state[id] = None if state.dtype == np.object else 0
        elif isinstance(state, Batch):
            state.empty_(id)

    def collect(
        self,
        n_step: Optional[int] = None,
        n_episode: Optional[Union[int, List[int]]] = None,
        random: bool = False,
        render: Optional[float] = None,
        no_grad: bool = True,
    ) -> Dict[str, float]:
        """Collect a specified number of steps or episodes.

        :param int n_step: how many steps you want to collect.
        :param n_episode: how many episodes you want to collect. If it is an
            int, it means to collect at least ``n_episode`` episodes; if it is
            a list, it means to collect exactly ``n_episode[i]`` episodes in
            the i-th environment.
        :param bool random: whether to use random policy for collecting data,
            defaults to False.
        :param float render: the sleep time between rendering consecutive
            frames, defaults to None (no rendering).
        :param bool no_grad: whether to compute actions under
            ``torch.no_grad()``, defaults to True (no gradient retained).

        .. note::

            One and only one collection number specification is permitted,
            either ``n_step`` or ``n_episode``.

        :return: A dict including the following keys

            * ``n/ep`` the collected number of episodes.
            * ``n/st`` the collected number of steps.
            * ``v/st`` the speed of steps per second.
            * ``v/ep`` the speed of episodes per second.
            * ``rew`` the mean reward over collected episodes.
            * ``len`` the mean length over collected episodes.
        """
        assert (n_step is not None and n_episode is None and n_step > 0) or (
            n_step is None and n_episode is not None and np.sum(n_episode) > 0
        ), "Only one of n_step or n_episode is allowed in " \
            f"Collector.collect, got n_step = {n_step}, " \
            f"n_episode = {n_episode}."
        start_time = time.time()
        step_count = 0
        succ_attacks = 0
        n_attacks = 0
        # episode count of each environment
        episode_count = np.zeros(self.env_num)
        # if n_episode is a list and some envs have collected the required
        # number of episodes, these envs are recorded in this list and will
        # not be stepped
        finished_env_ids = []
        rewards = []
        if isinstance(n_episode, list):
            assert len(n_episode) == self.get_env_num()
            finished_env_ids = [
                i for i in self._ready_env_ids if n_episode[i] <= 0]
            self._ready_env_ids = np.array(
                [x for x in self._ready_env_ids if x not in finished_env_ids])
        while True:
            if step_count >= 100000 and episode_count.sum() == 0:
                warnings.warn(
                    "There are already many steps in an episode. "
" "You should add a time limitation to your environment!", Warning) # restore the state and the input data last_state = self.data.state if isinstance(last_state, Batch) and last_state.is_empty(): last_state = None self.data.update(state=Batch(), obs_next=Batch(), policy=Batch()) # calculate the next action if random: spaces = self._action_space result = Batch( act=[spaces[i].sample() for i in self._ready_env_ids]) else: if no_grad: with torch.no_grad(): # faster than retain_grad version result = self.policy(self.data, last_state) else: result = self.policy(self.data, last_state) state = result.get("state", Batch()) # convert None to Batch(), since None is reserved for 0-init if state is None: state = Batch() self.data.update(state=state, policy=result.get("policy", Batch())) # save hidden state to policy._state, in order to save into buffer if not (isinstance(state, Batch) and state.is_empty()): self.data.policy._state = self.data.state self.data.act = to_numpy(result.act) # START ADVERSARIAL ATTACK x = rd.uniform(0, 1) if x < self.atk_frequency: ori_act = self.data.act adv_act, adv_obs = self.obs_attacks(self.data, ori_act) for j, i in enumerate(self._ready_env_ids): if adv_act[i] != ori_act[i]: succ_attacks += 1 n_attacks += self.env_num self.data.update( obs=adv_obs ) # so that the adv obs will be inserted in the buffer if self.test: self.data.act = adv_act # step in env obs_next, rew, done, info = self.env.step(self.data.act) # move data to self.data self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) if render: self.env.render() time.sleep(render) # add data into the buffer if self.preprocess_fn: result = self.preprocess_fn(**self.data) # type: ignore self.data.update(result) for j, i in enumerate(self._ready_env_ids): # j is the index in current ready_env_ids # i is the index in all environments if self.buffer is None: # users do not want to store data, so we store # small fake data here to make the code clean self._cached_buf[i].add(obs=0, act=0, rew=rew[j], done=0) else: self._cached_buf[i].add(**self.data[j]) if done[j]: if not (isinstance(n_episode, list) and episode_count[i] >= n_episode[i]): episode_count[i] += 1 rewards.append( self._rew_metric( np.sum(self._cached_buf[i].rew, axis=0))) step_count += len(self._cached_buf[i]) if self.buffer is not None: self.buffer.update(self._cached_buf[i]) if isinstance(n_episode, list) and \ episode_count[i] >= n_episode[i]: # env i has collected enough data, it has finished finished_env_ids.append(i) self._cached_buf[i].reset() self._reset_state(j) obs_next = self.data.obs_next if sum(done): env_ind_local = np.where(done)[0] env_ind_global = self._ready_env_ids[env_ind_local] obs_reset = self.env.reset(env_ind_global) if self.preprocess_fn: obs_reset = self.preprocess_fn(obs=obs_reset).get( "obs", obs_reset) obs_next[env_ind_local] = obs_reset self.data.obs = obs_next self._ready_env_ids = np.array( [x for x in self._ready_env_ids if x not in finished_env_ids]) if n_step: if step_count >= n_step: break else: if isinstance(n_episode, int) and \ episode_count.sum() >= n_episode: break if isinstance(n_episode, list) and \ (episode_count >= n_episode).all(): break # finished envs are ready, and can be used for the next collection self._ready_env_ids = np.array(self._ready_env_ids.tolist() + finished_env_ids) # generate the statistics episode_count = sum(episode_count) duration = max(time.time() - start_time, 1e-9) self.collect_step += step_count self.collect_episode += episode_count self.collect_time += duration return { "n/ep": 
            "n/ep": episode_count,
            "n/st": step_count,
            "v/st": step_count / duration,
            "v/ep": episode_count / duration,
            "rew": np.mean(rewards),
            "rew_std": np.std(rewards),
            "len": step_count / episode_count,
            'succ_atks(%)': succ_attacks / n_attacks if n_attacks > 0 else 0,
        }

    def obs_attacks(self, data: Batch, target_action: List[int]):
        """Perform an image adversarial attack on the observations stored in
        ``obs`` with respect to the actions ``target_action``, using the
        attack defined in ``self.obs_adv_atk``.
        """
        data = deepcopy(data)
        # convert the observations and actions to tensors
        obs = torch.FloatTensor(data.obs).to(self.device)
        act = torch.tensor(target_action).to(self.device)
        # craft the adversarial observations
        adv_obs = self.obs_adv_atk.perturb(obs, act)
        adv_obs = adv_obs.cpu().detach().numpy()
        data.obs = adv_obs
        # query the policy on the perturbed observations
        with torch.no_grad():
            result = self.policy(data, last_state=None)
        return to_numpy(result.act), adv_obs
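
# Usage sketch (illustrative): wiring an advertorch attack into the collector
# above. The eps/atk_frequency values are example settings, and we assume
# ``policy.model`` is a differentiable network mapping image observations to
# per-action logits, as required by the attack's ``predict`` argument.
def _example_adversarial_collection(policy: BasePolicy,
                                    env: gym.Env) -> Dict[str, float]:
    from advertorch.attacks import GradientSignAttack
    atk = GradientSignAttack(policy.model, eps=0.01,
                             clip_min=0.0, clip_max=1.0)
    collector = adversarial_training_collector(
        policy, env, atk, buffer=ReplayBuffer(size=20000), atk_frequency=0.5)
    return collector.collect(n_episode=10)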