Exemplo n.º 1
def test_batch():
    """Check attribute assignment, concatenation, indexing and splitting of ``Batch``."""
    batch = Batch(obs=[0], np=np.zeros([3, 4]))
    # overwrite ``obs`` so the following assertions have a known value
    batch.obs = [1]
    assert batch.obs == [1]
    # concatenate the batch with itself: ``obs`` and ``np`` double in length
    batch.append(batch)
    assert batch.obs == [1, 1]
    assert batch.np.shape == (6, 4)
    assert batch[0].obs == batch[1].obs
    with pytest.raises(IndexError):
        # ``obs`` only holds two entries after the concatenation
        batch[2]
    batch.obs = np.arange(5)
    for i, b in enumerate(batch.split(1, permute=False)):
        assert b.obs == batch[i].obs
Exemplo n.º 2
def test_batch_over_batch():
    """Check nested ``Batch`` objects: indexing, ``cat_``, ``update`` and slicing."""
    batch = Batch(a=[3, 4, 5], b=[4, 5, 6])
    batch2 = Batch({'c': [6, 7, 8], 'b': batch})
    batch2.b.b[-1] = 0
    for k, v in batch2.items():
        assert np.all(batch2[k] == v)
    assert batch2[-1].b.b == 0
    batch2.cat_(Batch(c=[6, 7, 8], b=batch))
    assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8])
    assert np.allclose(batch2.b.a, [3, 4, 5, 3, 4, 5])
    assert np.allclose(batch2.b.b, [4, 5, 0, 4, 5, 0])
    batch2.update(batch2.b, six=[6, 6, 6])
    assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8])
    assert np.allclose(batch2.a, [3, 4, 5, 3, 4, 5])
    assert np.allclose(batch2.b, [4, 5, 0, 4, 5, 0])
    assert np.allclose(batch2.six, [6, 6, 6])
    d = {'a': [3, 4, 5], 'b': [4, 5, 6]}
    batch3 = Batch(c=[6, 7, 8], b=d)
    batch3.cat_(Batch(c=[6, 7, 8], b=d))
    assert np.allclose(batch3.c, [6, 7, 8, 6, 7, 8])
    assert np.allclose(batch3.b.a, [3, 4, 5, 3, 4, 5])
    assert np.allclose(batch3.b.b, [4, 5, 6, 4, 5, 6])
    batch4 = Batch(({'a': {'b': np.array([1.0])}},))
    assert batch4.a.b.ndim == 2
    assert batch4.a.b[0, 0] == 1.0
    # advanced slicing
    batch5 = Batch(a=[[1, 2]], b={'c': np.zeros([3, 2, 1])})
    assert batch5.shape == [1, 2]
    # out-of-range index on each axis of the [1, 2] shape must raise
    with pytest.raises(IndexError):
        batch5[2]
    with pytest.raises(IndexError):
        batch5[:, 3]
    with pytest.raises(IndexError):
        batch5[:, :, -1]
    batch5[:, -1] += 1
    assert np.allclose(batch5.a, [1, 3])
    assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3)
    with pytest.raises(ValueError):
        batch5[:, -1] = 1
    batch5[:, 0] = {'a': -1}
    assert np.allclose(batch5.a, [-1, 3])
    assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3)
Exemplo n.º 3
class base_attack_collector:
    # Base class for "attack collectors": it steps ``policy`` in ``env``,
    # can perturb observations with ``obs_adv_atk`` (see ``obs_attacks``) and
    # keeps the counters reported by ``get_attack_stats``. Sub-classes must
    # implement ``collect``.
    # NOTE(review): the class docstring below has lost its opening/closing
    # triple quotes in this copy, so the ``:param`` lines are not valid Python
    # as written; ``device`` is also undocumented there.
    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
    :param env: a ``gym.Env`` environment or an instance of the
        :class:`~tianshou.env.BaseVectorEnv` class.
    :param obs_adv_atk: an instance of the :class:`~advertorch.attacks.base.Attack`
        class implementing an image adversarial attack.
    :param perfect_attack: force adversarial attacks on observations to be
        always effective (ignore the ``adv`` param).
    def __init__(self,
                 policy: BasePolicy,
                 env: gym.Env,
                 obs_adv_atk: Attack,
                 perfect_attack: bool = False,
                 device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        """Store the collaborators and zero every attack counter.

        ``device`` is where observations/actions are moved before attacking;
        its default is evaluated once, when this ``def`` is executed.
        """
        self.device = device
        self.policy = policy
        self.env = env
        self.obs_adv_atk = obs_adv_atk
        self.perfect_attack = perfect_attack
        # Falls back to ``action_space.n`` when ``shape`` is empty/falsy
        # (presumably a discrete gym space) — TODO confirm.
        self.action_space = self.env.action_space.shape or self.env.action_space.n
        # NOTE(review): this ``Batch(`` call is never closed in this copy; the
        # remaining keyword arguments and the closing ``)`` appear to be missing.
        self.data = Batch(state={},
        self.episode_count = 0  # current number of episodes
        self.reward_total = 0.  # total episode cumulative reward
        self.frames_count = 0  # number of observed frames
        self.n_attacks = 0  # number of attacks performed
        self.succ_attacks = 0  # number of successful image attacks
        self.start_time = 0  # time when the attack starts

    def reset_env(self) -> None:
        """Reset ``self.env`` and store its initial observation in ``self.data.obs``."""
        self.data.obs = self.env.reset()

    def render(self, **kwargs) -> None:
        """Forward ``kwargs`` to ``self.env.render``.

        NOTE(review): annotated ``-> None`` but actually returns whatever
        ``self.env.render`` returns.
        """
        return self.env.render(**kwargs)

    def reset_attack(self) -> None:
        """Zero all attack counters and restart the wall-clock timer."""
        self.episode_count, self.reward_total, self.frames_count,\
            self.n_attacks, self.succ_attacks = 0, 0, 0, 0, 0
        self.start_time = time.time()

    def get_attack_stats(self) -> Dict[str, float]:
        """Return throughput and per-episode attack statistics.

        Side effect: if no episode has finished yet, ``self.episode_count`` is
        set to 1 so the per-episode ratios do not divide by zero.

        NOTE(review): the dictionary keys and the closing ``}`` appear to be
        missing from this copy, leaving only the values; the key names are
        presumably those listed in ``collect``'s docstring — confirm.
        NOTE(review): ``self.n_attacks / self.frames_count`` raises
        ZeroDivisionError when no frame has been counted yet.
        """
        # Clamp so that a call right after ``reset_attack`` cannot divide by 0.
        duration = max(time.time() - self.start_time, 1e-9)
        if self.episode_count == 0:
            self.episode_count = 1
        return {
            self.frames_count / duration,
            self.episode_count / duration,
            self.reward_total / self.episode_count,
            self.frames_count / self.episode_count,
            self.n_attacks / self.episode_count,
            self.succ_attacks / self.episode_count,
            self.n_attacks / self.frames_count,
            self.succ_attacks / self.n_attacks if self.n_attacks > 0 else 0,

    def show_warning(self) -> None:
        """Flag runs of 100000+ frames in which no episode has finished."""
        if self.frames_count >= 100000 and self.episode_count == 0:
            # NOTE(review): the call that should emit the message below
            # (presumably ``warnings.warn(``) and its closing ``)`` are missing
            # in this copy; the string lines are over-indented as a result.
                'There are already many steps in an episode. '
                'You should add a time limitation to your environment!',

    def check_end_attack(self, n_step, n_episode) -> bool:
        """Return True once either stopping criterion is met.

        A falsy ``n_step`` / ``n_episode`` (0 or None) disables that criterion;
        otherwise the attack ends when ``frames_count >= n_step`` or
        ``episode_count >= n_episode``.
        """
        if n_step:
            if self.frames_count >= n_step:
                return True
        if n_episode:
            if self.episode_count >= n_episode:
                return True
        return False

    def perform_step(self) -> None:
        # Steps the env with the first entry of the batched ``self.data.act``,
        # stores obs_next/rew/done/info, adds ``rew`` to ``reward_total``,
        # counts a finished episode on ``done`` and advances ``self.data.obs``.
        # ``frames_count`` is not updated here (presumably done by callers).
        # NOTE(review): the next line is the docstring with its triple quotes
        # missing in this copy.
        Performs action 'self.data.act' on 'self.env' and store the next observation in 'self.data.obs'
        obs_next, rew, done, info = self.env.step(self.data.act[0])
        self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
        self.reward_total += rew
        if self.data.done:
            self.episode_count += 1
        self.data.obs = self.data.obs_next

    def predict_next_action(self):
        # Runs the policy without gradient tracking and stores the numpy action
        # in ``self.data.act``. ``self.data.obs`` is replaced in place by a copy
        # with an extra leading axis of size 1 (``np.expand_dims(..., 0)``).
        # NOTE(review): the three lines below are the docstring with its triple
        # quotes missing in this copy.
        Predicts the next action given observation 'self.data.obs' and policy 'self.policy',
        and stores it in 'self.data.act'
        :return: outcome of policy forward pass
        with torch.no_grad():
            self.data.obs = np.expand_dims(self.data.obs, axis=0)
            result = self.policy(self.data, last_state=None)
        self.data.act = to_numpy(result.act)
        return result

    # NOTE(review): this signature has lost ``self,`` and its closing ``):``
    # in this copy, and the docstring below has lost its triple quotes.
    def obs_attacks(
        target_action: Optional[List[int]] = None,
        Performs an image adversarial attack on the observation stored in 'self.data.obs' respect to
        the action 'target_action' using the method defined in 'self.obs_adv_atk'
        :param target_action:
                - if obs_adv_atk.targeted=False, then 'target_action' must be the normal action.
                - if obs_adv_atk.targeted=True, then 'target_action' must be the adversarial action.
        # Default to the action the policy chose on the clean observation.
        # NOTE(review): ``not target_action`` is ambiguous (raises) for a numpy
        # array with more than one element; ``target_action is None`` is
        # presumably what is meant — confirm.
        if not target_action:
            target_action = self.data.act
        obs = torch.FloatTensor(self.data.obs).to(
            self.device)  # convert observation to tensor
        act = torch.tensor(target_action).to(
            self.device)  # convert action to tensor
        adv_obs = self.obs_adv_atk.perturb(
            obs, act)  # create adversarial observation
        # Query the policy on a deep copy so ``self.data.obs`` keeps the clean
        # observation; only ``self.data.act`` is overwritten with the action
        # chosen on the perturbed observation.
        with torch.no_grad():
            data = copy.deepcopy(self.data)
            data.obs = adv_obs.cpu().detach().numpy()
            result = self.policy(data, last_state=None)
        self.data.act = to_numpy(result.act)

    def collect(self,
                n_step: int = 0,
                n_episode: int = 0,
                render: Optional[float] = None) -> Dict[str, float]:
        # Abstract hook: always raises NotImplementedError here.
        # NOTE(review): the docstring below has lost its triple quotes in this
        # copy.
        :param int n_step: how many steps you want to collect.
        :param n_episode: how many episodes you want to collect.
        :param float render: the sleep time between rendering consecutive
            frames, defaults to ``None`` (no rendering).
        :return: A dict including the following keys
            * ``n/ep`` the collected number of episodes.
            * ``n/st`` the collected number of steps.
            * ``v/st`` the speed of steps per second.
            * ``v/ep`` the speed of episode per second.
            * ``rew`` the mean reward over collected episodes.
            * ``len`` the mean length over collected episodes.
            * ``n_attacks`` number of performed attacks.
            * ``n_succ_attacks`` number of performed successful attacks.
            * ``n_attacks(%)`` ratio of performed attacks over steps.
            * ``succ_atks(%)`` ratio of successful attacks over performed attacks.

        error = "Sub-classes must implement 'collect'."
        raise NotImplementedError(error)
Exemplo n.º 4
class Collector(object):
    """Collector enables the policy to interact with different types of envs with \
    exact number of steps or episodes.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
    :param env: a ``gym.Env`` environment or an instance of the
        :class:`~tianshou.env.BaseVectorEnv` class. A single ``gym.Env`` without
        ``__len__`` is wrapped into a ``DummyVectorEnv`` (with a warning).
    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
        If set to None, a ``VectorReplayBuffer(env_num, env_num)`` is created.
        A plain ``ReplayBuffer`` / ``PrioritizedReplayBuffer`` is only accepted
        when there is a single env; use the vectorized variant otherwise.
        Default to None.
    :param function preprocess_fn: a function called before the data has been added to
        the buffer, see issue #42 and :ref:`preprocess_fn`. Default to None.
    :param bool exploration_noise: determine whether the action needs to be modified
        with corresponding policy's exploration noise. If so, "policy.
        exploration_noise(act, batch)" will be called automatically to add the
        exploration noise into action. Default to False.

    The "preprocess_fn" is a function called before the data has been added to the
    buffer with batch format. It will receive only "obs" and "env_id" when the
    collector resets the environment, and will receive six keys "obs_next", "rew",
    "done", "info", "policy" and "env_id" in a normal env step. It returns either a
    dict or a :class:`~tianshou.data.Batch` with the modified keys and values. Examples
    are in "test/base/test_collector.py".

    .. note::

        Please make sure the given environment has a time limitation if using n_episode
        collect option.
    """

    def __init__(
        self,
        policy: BasePolicy,
        env: Union[gym.Env, BaseVectorEnv],
        buffer: Optional[ReplayBuffer] = None,
        preprocess_fn: Optional[Callable[..., Batch]] = None,
        exploration_noise: bool = False,
    ) -> None:
        import warnings

        super().__init__()
        if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
            warnings.warn(
                "Single environment detected, wrap to DummyVectorEnv.")
            # Bind the single env to its own name so the factory closure cannot
            # late-bind to ``env`` after it is rebound on the next line.
            single_env = env
            env = DummyVectorEnv([lambda: single_env])
        self.env = env
        self.env_num = len(env)
        self.exploration_noise = exploration_noise
        # needs ``self.env_num``, so it must run after ``len(env)``
        self._assign_buffer(buffer)
        self.policy = policy
        self.preprocess_fn = preprocess_fn
        self._action_space = env.action_space
        # avoid creating attribute outside __init__
        self.reset()

    def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None:
        """Check if the buffer matches the constraint and store it in ``self.buffer``.

        :raises TypeError: if a non-vectorized buffer is given for more than one env.
        """
        if buffer is None:
            buffer = VectorReplayBuffer(self.env_num, self.env_num)
        elif isinstance(buffer, ReplayBufferManager):
            assert buffer.buffer_num >= self.env_num
            if isinstance(buffer, CachedReplayBuffer):
                assert buffer.cached_buffer_num >= self.env_num
        else:  # ReplayBuffer or PrioritizedReplayBuffer
            assert buffer.maxsize > 0
            if self.env_num > 1:
                # exact type check on purpose: subclasses of ReplayBuffer get
                # the prioritized suggestion below
                if type(buffer) is ReplayBuffer:
                    buffer_type = "ReplayBuffer"
                    vector_type = "VectorReplayBuffer"
                else:
                    buffer_type = "PrioritizedReplayBuffer"
                    vector_type = "PrioritizedVectorReplayBuffer"
                raise TypeError(
                    f"Cannot use {buffer_type}(size={buffer.maxsize}, ...) to collect "
                    f"{self.env_num} envs,\n\tplease use {vector_type}(total_size="
                    f"{buffer.maxsize}, buffer_num={self.env_num}, ...) instead.")
        self.buffer = buffer

    def reset(self) -> None:
        """Reset all related variables in the collector.

        Re-creates the empty ``self.data`` batch, then resets the envs, the
        buffer and the statistics.
        """
        # use empty Batch for "state" so that self.data supports slicing
        # convert empty Batch to None when passing data to policy
        self.data = Batch(obs={}, act={}, rew={}, done={},
                          obs_next={}, info={}, policy={})
        self.reset_env()
        self.reset_buffer()
        self.reset_stat()

    def reset_stat(self) -> None:
        """Reset the statistic variables."""
        self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0

    def reset_buffer(self, keep_statistics: bool = False) -> None:
        """Reset the data buffer."""
        self.buffer.reset(keep_statistics=keep_statistics)

    def reset_env(self) -> None:
        """Reset all of the environments and store the (preprocessed) obs in self.data."""
        obs = self.env.reset()
        if self.preprocess_fn:
            # preprocess_fn may omit "obs" from its result; keep the raw obs then
            obs = self.preprocess_fn(obs=obs, env_id=np.arange(
                self.env_num)).get("obs", obs)
        self.data.obs = obs

    def _reset_state(self, id: Union[int, List[int]]) -> None:
        """Reset the hidden state: self.data.state[id]."""
        if hasattr(self.data.policy, "hidden_state"):
            state = self.data.policy.hidden_state  # it is a reference
            if isinstance(state, torch.Tensor):
                state[id].zero_()
            elif isinstance(state, np.ndarray):
                state[id] = None if state.dtype == object else 0
            elif isinstance(state, Batch):
                state.empty_(id)

    def collect(
        self,
        n_step: Optional[int] = None,
        n_episode: Optional[int] = None,
        random: bool = False,
        render: Optional[float] = None,
        no_grad: bool = True,
    ) -> Dict[str, Any]:
        """Collect a specified number of step or episode.

        To ensure unbiased sampling result with n_episode option, this function will
        first collect ``n_episode - env_num`` episodes, then for the last ``env_num``
        episodes, they will be collected evenly from each env.

        :param int n_step: how many steps you want to collect.
        :param int n_episode: how many episodes you want to collect.
        :param bool random: whether to use random policy for collecting data. Default
            to False.
        :param float render: the sleep time between rendering consecutive frames.
            Default to None (no rendering).
        :param bool no_grad: whether to retain gradient in policy.forward(). Default to
            True (no gradient retaining).

        .. note::

            One and only one collection number specification is permitted, either
            ``n_step`` or ``n_episode``.

        :raises TypeError: if neither ``n_step`` nor ``n_episode`` is given.

        :return: A dict including the following keys

            * ``n/ep`` collected number of episodes.
            * ``n/st`` collected number of steps.
            * ``rews`` array of episode reward over collected episodes.
            * ``lens`` array of episode length over collected episodes.
            * ``idxs`` array of episode start index in buffer over collected episodes.
        """
        import warnings

        assert not self.env.is_async, "Please use AsyncCollector if using async venv."
        if n_step is not None:
            assert n_episode is None, (
                f"Only one of n_step or n_episode is allowed in Collector."
                f"collect, got n_step={n_step}, n_episode={n_episode}.")
            assert n_step > 0
            if not n_step % self.env_num == 0:
                warnings.warn(
                    f"n_step={n_step} is not a multiple of #env ({self.env_num}), "
                    "which may cause extra transitions collected into the buffer."
                )
            ready_env_ids = np.arange(self.env_num)
        elif n_episode is not None:
            assert n_episode > 0
            ready_env_ids = np.arange(min(self.env_num, n_episode))
            self.data = self.data[:min(self.env_num, n_episode)]
        else:
            raise TypeError(
                "Please specify at least one (either n_step or n_episode) "
                "in Collector.collect().")

        start_time = time.time()

        step_count = 0
        episode_count = 0
        episode_rews = []
        episode_lens = []
        episode_start_indices = []

        while True:
            assert len(self.data) == len(ready_env_ids)
            # restore the state: if the last state is None, it won't store
            last_state = self.data.policy.pop("hidden_state", None)

            # get the next action
            if random:
                self.data.update(
                    act=[self._action_space[i].sample() for i in ready_env_ids])
            else:
                if no_grad:
                    with torch.no_grad():  # faster than retain_grad version
                        # self.data.obs will be used by agent to get result
                        result = self.policy(self.data, last_state)
                else:
                    result = self.policy(self.data, last_state)
                # update state / act / policy into self.data
                policy = result.get("policy", Batch())
                assert isinstance(policy, Batch)
                state = result.get("state", None)
                if state is not None:
                    policy.hidden_state = state  # save state into buffer
                act = to_numpy(result.act)
                if self.exploration_noise:
                    act = self.policy.exploration_noise(act, self.data)
                self.data.update(policy=policy, act=act)

            # get bounded and remapped actions first (not saved into buffer)
            action_remap = self.policy.map_action(self.data.act)
            # step in env
            result = self.env.step(action_remap, ready_env_ids)  # type: ignore
            obs_next, rew, done, info = result

            self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
            if self.preprocess_fn:
                self.data.update(
                    self.preprocess_fn(
                        obs_next=self.data.obs_next,
                        rew=self.data.rew,
                        done=self.data.done,
                        info=self.data.info,
                        policy=self.data.policy,
                        env_id=ready_env_ids,
                    ))

            if render:
                self.env.render()
                if render > 0 and not np.isclose(render, 0):
                    time.sleep(render)

            # add data into the buffer
            ptr, ep_rew, ep_len, ep_idx = self.buffer.add(
                self.data, buffer_ids=ready_env_ids)

            # collect statistics
            step_count += len(ready_env_ids)

            if np.any(done):
                env_ind_local = np.where(done)[0]
                env_ind_global = ready_env_ids[env_ind_local]
                episode_count += len(env_ind_local)
                # record per-episode statistics of the envs that just finished
                episode_lens.append(ep_len[env_ind_local])
                episode_rews.append(ep_rew[env_ind_local])
                episode_start_indices.append(ep_idx[env_ind_local])
                # now we copy obs_next to obs, but since there might be
                # finished episodes, we have to reset finished envs first.
                obs_reset = self.env.reset(env_ind_global)
                if self.preprocess_fn:
                    obs_reset = self.preprocess_fn(
                        obs=obs_reset, env_id=env_ind_global).get(
                            "obs", obs_reset)
                self.data.obs_next[env_ind_local] = obs_reset
                for i in env_ind_local:
                    self._reset_state(i)

                # remove surplus env id from ready_env_ids
                # to avoid bias in selecting environments
                if n_episode:
                    surplus_env_num = len(ready_env_ids) - (n_episode -
                                                            episode_count)
                    if surplus_env_num > 0:
                        mask = np.ones_like(ready_env_ids, dtype=bool)
                        mask[env_ind_local[:surplus_env_num]] = False
                        ready_env_ids = ready_env_ids[mask]
                        self.data = self.data[mask]

            self.data.obs = self.data.obs_next

            if (n_step and step_count >= n_step) or \
                    (n_episode and episode_count >= n_episode):
                break

        # generate statistics
        self.collect_step += step_count
        self.collect_episode += episode_count
        self.collect_time += max(time.time() - start_time, 1e-9)

        if n_episode:
            # self.data was sliced to the ready envs above; rebuild it for all envs
            self.data = Batch(obs={}, act={}, rew={}, done={},
                              obs_next={}, info={}, policy={})
            self.reset_env()

        if episode_count > 0:
            rews, lens, idxs = list(
                map(np.concatenate,
                    [episode_rews, episode_lens, episode_start_indices]))
        else:
            rews, lens, idxs = np.array([]), np.array([], int), np.array([], int)

        return {
            "n/ep": episode_count,
            "n/st": step_count,
            "rews": rews,
            "lens": lens,
            "idxs": idxs,
        }
Exemplo n.º 5
def test_batch():
    """Broad check of ``Batch`` construction, emptiness, update, cat_, split and indexing."""
    assert list(Batch()) == []
    assert Batch().is_empty()
    assert not Batch(b={'c': {}}).is_empty()
    assert Batch(b={'c': {}}).is_empty(recurse=True)
    assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
    assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
    assert not Batch(d=1).is_empty()
    assert not Batch(a=np.float64(1.0)).is_empty()
    assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
    assert not Batch(a=[1, 2, 3]).is_empty()
    b = Batch({'a': [4, 4], 'b': [5, 5]}, c=[None, None])
    assert b.c.dtype == object
    b = Batch(d=[None], e=[starmap], f=Batch)
    assert b.d.dtype == b.e.dtype == object and b.f == Batch
    b = Batch()
    assert b.is_empty()
    b.update(c=[3, 5])
    assert np.allclose(b.c, [3, 5])
    # mimic the behavior of dict.update, where kwargs can overwrite keys
    b.update({'a': 2}, a=3)
    assert 'a' in b and b.a == 3
    assert b.pop('a') == 3
    assert 'a' not in b
    with pytest.raises(AssertionError):
        Batch({1: 2})
    assert Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]).a.dtype == object
    with pytest.raises(TypeError):
        Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])
    batch = Batch(a=[torch.ones(3), torch.ones(3)])
    assert torch.allclose(batch.a, torch.ones(2, 3))
    # concatenating with itself doubles the leading dimension
    batch.cat_(batch)
    assert torch.allclose(batch.a, torch.ones(4, 3))
    batch = Batch(obs=[0], np=np.zeros([3, 4]))
    assert batch.obs == batch["obs"]
    batch.obs = [1]
    assert batch.obs == [1]
    batch.cat_(batch)
    assert np.allclose(batch.obs, [1, 1])
    assert batch.np.shape == (6, 4)
    assert np.allclose(batch[0].obs, batch[1].obs)
    batch.obs = np.arange(5)
    for i, b in enumerate(batch.split(1, shuffle=False)):
        if i != 5:
            assert b.obs == batch[i].obs
        else:
            # ``obs`` has only 5 entries, so the 6th row has no ``obs``
            with pytest.raises(AttributeError):
                batch[i].obs
            with pytest.raises(AttributeError):
                b.obs
    batch = Batch(a=np.arange(10))
    with pytest.raises(AssertionError):
        list(batch.split(0))
    data = [
        (1, False, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]),
        (1, True, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]),
        (3, False, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]),
        (3, True, [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]),
        (5, False, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
        (5, True, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
        (7, False, [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]),
        (7, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (10, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (10, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (15, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (15, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (100, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (100, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
    ]
    for size, merge_last, result in data:
        bs = list(batch.split(size, shuffle=False, merge_last=merge_last))
        assert [bs[i].a.tolist() for i in range(len(bs))] == result
    batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])}
    batch_item = Batch({'a': [batch_dict]})[0]
    assert isinstance(batch_item.a.b, np.ndarray)
    assert batch_item.a.b == batch_dict['b']
    assert isinstance(batch_item.a.c, float)
    assert batch_item.a.c == batch_dict['c']
    assert isinstance(batch_item.a.d, torch.Tensor)
    assert batch_item.a.d == batch_dict['d']
    batch2 = Batch(a=[{
        'b': np.float64(1.0),
        'c': np.zeros(1),
        'd': Batch(e=np.array(3.0))}])
    assert len(batch2) == 1
    assert Batch().shape == []
    assert Batch(a=1).shape == []
    assert Batch(a=set((1, 2, 1))).shape == []
    assert batch2.shape[0] == 1
    assert 'a' in batch2 and all([i in batch2.a for i in 'bcd'])
    # batch2 has length 1: any other index is out of range
    with pytest.raises(IndexError):
        batch2[-2]
    with pytest.raises(IndexError):
        batch2[1]
    assert batch2[0].shape == []
    # a single item is scalar-shaped: it can be neither indexed nor measured
    with pytest.raises(IndexError):
        batch2[0][0]
    with pytest.raises(TypeError):
        len(batch2[0])
    assert isinstance(batch2[0].a.c, np.ndarray)
    assert isinstance(batch2[0].a.b, np.float64)
    assert isinstance(batch2[0].a.d.e, np.float64)
    batch2_from_list = Batch(list(batch2))
    batch2_from_comp = Batch([e for e in batch2])
    assert batch2_from_list.a.b == batch2.a.b
    assert batch2_from_list.a.c == batch2.a.c
    assert batch2_from_list.a.d.e == batch2.a.d.e
    assert batch2_from_comp.a.b == batch2.a.b
    assert batch2_from_comp.a.c == batch2.a.c
    assert batch2_from_comp.a.d.e == batch2.a.d.e
    for batch_slice in [batch2[slice(0, 1)], batch2[:1], batch2[0:]]:
        assert batch_slice.a.b == batch2.a.b
        assert batch_slice.a.c == batch2.a.c
        assert batch_slice.a.d.e == batch2.a.d.e
    batch2.a.d.f = {}
    batch2_sum = (batch2 + 1.0) * 2
    assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
    assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
    assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
    assert batch2_sum.a.d.f.is_empty()
    with pytest.raises(TypeError):
        batch2 += [1]
    batch3 = Batch(a={
        'c': np.zeros(1),
        'd': Batch(e=np.array([0.0]), f=np.array([3.0]))})
    batch3.a.d[0] = {'e': 4.0}
    assert batch3.a.d.e[0] == 4.0
    batch3.a.d[0] = Batch(f=5.0)
    assert batch3.a.d.f[0] == 5.0
    with pytest.raises(ValueError):
        batch3.a.d[0] = Batch(f=5.0, g=0.0)
    with pytest.raises(ValueError):
        batch3[0] = Batch(a={"c": 2, "e": 1})
    # auto convert
    batch4 = Batch(a=np.array(['a', 'b']))
    assert batch4.a.dtype == object  # auto convert to object
    batch4.update(a=np.array(['c', 'd']))
    assert list(batch4.a) == ['c', 'd']
    assert batch4.a.dtype == object  # auto convert to object
    batch5 = Batch(a=np.array([{'index': 0}]))
    assert isinstance(batch5.a, Batch)
    assert np.allclose(batch5.a.index, [0])
    batch5.b = np.array([{'index': 1}])
    assert isinstance(batch5.b, Batch)
    assert np.allclose(batch5.b.index, [1])

    # None is a valid object and can be stored in Batch
    a = Batch.stack([Batch(a=None), Batch(b=None)])
    assert a.a[0] is None and a.a[1] is None
    assert a.b[0] is None and a.b[1] is None

    # nx.Graph corner case
    assert Batch(a=np.array([nx.Graph(), nx.Graph()], dtype=object)).a.dtype == object
    g1 = nx.Graph()
    g2 = nx.Graph()
    assert Batch(a=np.array([g1, g2])).a.dtype == object
class Collector(object):
    """The :class:`~tianshou.data.Collector` enables the policy to interact
    with different types of environments conveniently.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param env: a ``gym.Env`` environment or an instance of the
        :class:`~tianshou.env.BaseVectorEnv` class.
    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer`
        class. If set to ``None``, the collected data is not stored.
    :param function preprocess_fn: a function called before the data has been
        added to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults
        to ``None``.
    :param int stat_size: for the moving average of recording speed, defaults
        to 100.
    :param BaseNoise action_noise: add a noise to continuous action. Normally
        a policy already has a noise param for exploration in training phase,
        so this is recommended to use in test collector for some purpose.
    :param function reward_metric: to be used in multi-agent RL. The reward to
        report is of shape [agent_num], but we need to return a single scalar
        to monitor training. This function specifies what is the desired
        metric, e.g., the reward of agent 1 or the average reward over all
        agents. By default only scalar rewards are accepted: a non-scalar
        mean reward raises an ``AssertionError`` asking for a metric.

    The ``preprocess_fn`` is a function called before the data has been added
    to the buffer with batch format, which receives up to 7 keys as listed in
    :class:`~tianshou.data.Batch`. It will receive with only ``obs`` when the
    collector resets the environment. It returns either a dict or a
    :class:`~tianshou.data.Batch` with the modified keys and values. Examples
    are in "test/base/test_collector.py".

    In addition to the standard statistics, :meth:`collect` sums a set of
    custom per-episode values read from the environment's ``info`` dict (see
    ``_EPISODE_INFO_KEYS``). The environment must report all of those keys in
    the ``info`` of the step that ends an episode (in the vectorized case only
    environment 0 is read).

    Example:
    ::

        policy = PGPolicy(...)  # or other policies if you wish
        env = gym.make('CartPole-v0')
        replay_buffer = ReplayBuffer(size=10000)
        # here we set up a collector with a single environment
        collector = Collector(policy, env, buffer=replay_buffer)

        # the collector supports vectorized environments as well
        envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
        collector = Collector(policy, envs, buffer=replay_buffer)

        # collect at least 3 episodes
        collector.collect(n_episode=3)
        # collect 1 episode for the first env, 3 for the third env
        collector.collect(n_episode=[1, 0, 3])
        # collect at least 2 steps
        collector.collect(n_step=2)
        # collect episodes with visual rendering (the render argument is the
        #   sleep time between rendering consecutive frames)
        collector.collect(n_episode=1, render=0.03)

        # sample data with a given number of batch-size:
        batch_data = collector.sample(batch_size=64)
        # policy.learn(batch_data)  # btw, vanilla policy gradient only
        #   supports on-policy training, so here we pick all data in the buffer
        batch_data = collector.sample(batch_size=0)
        policy.learn(batch_data)
        # on-policy algorithms use the collected data only once, so here we
        #   clear the buffer
        collector.reset_buffer()

    For the scenario of collecting data from multiple environments to a single
    buffer, the cache buffers will turn on automatically. It may return the
    data more than the given limitation.

    .. note::

        Please make sure the given environment has a time limitation.
    """

    # Custom per-episode statistics: (key in the env ``info`` dict, key in the
    # dict returned by ``collect``). Values are summed over finished episodes.
    _EPISODE_INFO_KEYS = (
        ('ty1_succ_rate_1', 'ty1s_1'),
        ('ty1_succ_rate_2', 'ty1s_2'),
        ('ty1_succ_rate_3', 'ty1s_3'),
        ('ty1_succ_rate_4', 'ty1s_4'),
        ('Q_len_1', 'ql_1'),
        ('Q_len_2', 'ql_2'),
        ('Q_len_3', 'ql_3'),
        ('Q_len_4', 'ql_4'),
        ('energy_effi_1', 'ee_1'),
        ('energy_effi_2', 'ee_2'),
        ('energy_effi_3', 'ee_3'),
        ('energy_effi_4', 'ee_4'),
        ('avg_rate', 'avg_r'),
        ('avg_power', 'avg_p'),
    )

    def __init__(
        self,
        policy: BasePolicy,
        env: Union[gym.Env, BaseVectorEnv],
        buffer: Optional[ReplayBuffer] = None,
        preprocess_fn: Optional[Callable[[Any], Union[dict, Batch]]] = None,
        stat_size: Optional[int] = 100,
        action_noise: Optional[BaseNoise] = None,
        reward_metric: Optional[Callable[[np.ndarray], float]] = None,
    ) -> None:
        self.env = env
        self.env_num = 1
        self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
        self.buffer = buffer
        self.policy = policy
        self.preprocess_fn = preprocess_fn
        self.process_fn = policy.process_fn
        self._multi_env = isinstance(env, BaseVectorEnv)
        # need multiple cache buffers only if storing in one buffer
        self._cached_buf = []
        if self._multi_env:
            self.env_num = len(env)
            self._cached_buf = [
                ListReplayBuffer() for _ in range(self.env_num)]
        self.stat_size = stat_size
        self._action_noise = action_noise

        self._rew_metric = reward_metric or Collector._default_rew_metric

        # create ``self.data`` and reset env(s), buffers and statistics
        self.reset()

    @staticmethod
    def _default_rew_metric(x):
        # this internal function is designed for single-agent RL
        # for multi-agent RL, a reward_metric must be provided
        assert np.asanyarray(x).size == 1, \
            'Please specify the reward_metric ' \
            'since the reward is not a scalar.'
        return x

    def reset(self) -> None:
        """Reset all related variables in the collector."""
        self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={},
                          obs_next={}, policy={})
        self.reset_env()
        self.reset_buffer()
        self.step_speed = MovAvg(self.stat_size)
        self.episode_speed = MovAvg(self.stat_size)
        self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
        if self._action_noise is not None:
            self._action_noise.reset()

    def reset_buffer(self) -> None:
        """Reset the main data buffer."""
        if self.buffer is not None:
            self.buffer.reset()

    def get_env_num(self) -> int:
        """Return the number of environments the collector have."""
        return self.env_num

    def reset_env(self) -> None:
        """Reset all of the environment(s)' states and reset all of the cache
        buffers (if need).
        """
        obs = self.env.reset()
        if not self._multi_env:
            obs = self._make_batch(obs)
        if self.preprocess_fn:
            obs = self.preprocess_fn(obs=obs).get('obs', obs)
        self.data.obs = obs
        self.reward = 0.  # will be specified when the first data is ready
        self.length = np.zeros(self.env_num)
        for b in self._cached_buf:
            b.reset()

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
        """Reset all the seed(s) of the given environment(s)."""
        return self.env.seed(seed)

    def render(self, **kwargs) -> None:
        """Render all the environment(s)."""
        return self.env.render(**kwargs)

    def close(self) -> None:
        """Close the environment(s)."""
        self.env.close()

    def _make_batch(self, data: Any) -> np.ndarray:
        """Return [data]."""
        if isinstance(data, np.ndarray):
            return data[None]
        else:
            return np.array([data])

    def _reset_state(self, id: Union[int, List[int]]) -> None:
        """Reset self.data.state[id]."""
        state = self.data.state  # it is a reference
        if isinstance(state, torch.Tensor):
            state[id].zero_()
        elif isinstance(state, np.ndarray):
            # builtin ``object``: the ``np.object`` alias was removed in
            # NumPy 1.24
            state[id] = None if state.dtype == object else 0
        elif isinstance(state, Batch):
            state.empty_(id)

    def _add_episode_info(self, sums: Dict[str, float], info: Any) -> None:
        """Add one finished episode's custom ``info`` statistics to ``sums``.

        :param sums: running totals keyed by the output names listed in
            ``_EPISODE_INFO_KEYS``; updated in place.
        :param info: the ``info`` entry of a single environment for the step
            that ended the episode; it must contain every input key listed in
            ``_EPISODE_INFO_KEYS``.
        """
        for info_key, out_key in self._EPISODE_INFO_KEYS:
            sums[out_key] += info[info_key]

    def collect(
        self,
        n_step: int = 0,
        n_episode: Union[int, List[int]] = 0,
        random: bool = False,
        render: Optional[float] = None,
        log_fn: Optional[Callable[[dict], None]] = None,
    ) -> Dict[str, float]:
        """Collect a specified number of step or episode.

        :param int n_step: how many steps you want to collect.
        :param n_episode: how many episodes you want to collect (in each
            environment).
        :type n_episode: int or list
        :param bool random: whether to use random policy for collecting data,
            defaults to ``False``.
        :param float render: the sleep time between rendering consecutive
            frames, defaults to ``None`` (no rendering).
        :param function log_fn: a function which receives env info, typically
            for tensorboard logging.

        .. note::

            One and only one collection number specification is permitted,
            either ``n_step`` or ``n_episode``.

        :return: A dict including the following keys

            * ``n/ep`` the collected number of episodes.
            * ``n/st`` the collected number of steps.
            * ``v/st`` the speed of steps per second.
            * ``v/ep`` the speed of episode per second.
            * ``rew`` the mean reward over collected episodes.
            * ``len`` the mean length over collected episodes.
            * ``ty1s_1`` ... ``avg_p`` the sums of the custom ``info``
              statistics (see ``_EPISODE_INFO_KEYS``) over the finished
              episodes that were inspected (only environment 0 in the
              vectorized case).
        """
        import warnings  # for the missing-time-limit warning below

        if not self._multi_env:
            n_episode = np.sum(n_episode)
        start_time = time.time()
        assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
            "One and only one collection number specification is permitted!"
        cur_step, cur_episode = 0, np.zeros(self.env_num)
        reward_sum, length_sum = 0., 0
        # running sums of the custom per-episode env statistics
        info_sums = {out_key: 0. for _, out_key in self._EPISODE_INFO_KEYS}

        while True:
            if cur_step >= 100000 and cur_episode.sum() == 0:
                warnings.warn(
                    'There are already many steps in an episode. '
                    'You should add a time limitation to your environment!',
                    Warning)

            # restore the state and the input data; a tensor/ndarray state
            # has no ``is_empty``, so only an empty Batch maps to None
            last_state = self.data.state
            if isinstance(last_state, Batch) and last_state.is_empty():
                last_state = None
            self.data.update(state=Batch(), obs_next=Batch(), policy=Batch())

            # calculate the next action
            if random:
                action_space = self.env.action_space
                if isinstance(action_space, list):
                    result = Batch(act=[a.sample() for a in action_space])
                else:
                    result = Batch(act=self._make_batch(action_space.sample()))
            else:
                with torch.no_grad():
                    result = self.policy(self.data, last_state)

            # convert None to Batch(), since None is reserved for 0-init
            state = result.get('state', Batch())
            if state is None:
                state = Batch()
            self.data.state = state
            if hasattr(result, 'policy'):
                self.data.policy = to_numpy(result.policy)
            # save hidden state to policy._state, in order to save into buffer
            self.data.policy._state = self.data.state

            self.data.act = to_numpy(result.act)
            if self._action_noise is not None:
                self.data.act += self._action_noise(self.data.act.shape)

            # step in env
            obs_next, rew, done, info = self.env.step(
                self.data.act if self._multi_env else self.data.act[0])

            # move data to self.data
            if not self._multi_env:
                obs_next = self._make_batch(obs_next)
                rew = self._make_batch(rew)
                done = self._make_batch(done)
                info = self._make_batch(info)
            self.data.obs_next = obs_next
            self.data.rew = rew
            self.data.done = done
            self.data.info = info

            if log_fn:
                log_fn(info if self._multi_env else info[0])
            if render:
                self.render()
                if render > 0:
                    time.sleep(render)

            # add data into the buffer
            self.length += 1
            self.reward += self.data.rew
            if self.preprocess_fn:
                result = self.preprocess_fn(**self.data)
                self.data.update(result)
            if self._multi_env:  # cache_buffer branch
                if self.data.done[0]:
                    # NOTE(review): only episodes finished by environment 0
                    # feed the custom statistics here - confirm intended.
                    self._add_episode_info(info_sums, self.data.info[0])
                for i in range(self.env_num):
                    self._cached_buf[i].add(**self.data[i])
                    if self.data.done[i]:
                        if n_step != 0 or np.isscalar(n_episode) or \
                                cur_episode[i] < n_episode[i]:
                            cur_episode[i] += 1
                            reward_sum += self.reward[i]
                            length_sum += self.length[i]
                            if self._cached_buf:
                                cur_step += len(self._cached_buf[i])
                                if self.buffer is not None:
                                    self.buffer.update(self._cached_buf[i])
                        self.reward[i], self.length[i] = 0., 0
                        if self._cached_buf:
                            self._cached_buf[i].reset()
                        self._reset_state(i)
                obs_next = self.data.obs_next
                if sum(self.data.done):
                    env_ind = np.where(self.data.done)[0]
                    obs_reset = self.env.reset(env_ind)
                    if self.preprocess_fn:
                        obs_next[env_ind] = self.preprocess_fn(
                            obs=obs_reset).get('obs', obs_reset)
                    else:
                        obs_next[env_ind] = obs_reset
                self.data.obs_next = obs_next
                if n_episode != 0:
                    if isinstance(n_episode, list) and \
                            (cur_episode >= np.array(n_episode)).all() or \
                            np.isscalar(n_episode) and \
                            cur_episode.sum() >= n_episode:
                        break
            else:  # single buffer, without cache_buffer
                if self.buffer is not None:
                    self.buffer.add(**self.data[0])
                cur_step += 1
                if self.data.done[0]:
                    # index the single entry (as the vectorized branch does)
                    # so that every accumulated statistic stays a scalar
                    self._add_episode_info(info_sums, self.data.info[0])
                    cur_episode += 1
                    reward_sum += self.reward[0]
                    length_sum += self.length[0]
                    self.reward, self.length = 0., np.zeros(self.env_num)
                    self.data.state = Batch()
                    obs_next = self._make_batch(self.env.reset())
                    if self.preprocess_fn:
                        obs_next = self.preprocess_fn(obs=obs_next).get(
                            'obs', obs_next)
                    self.data.obs_next = obs_next
                if n_episode != 0 and cur_episode >= n_episode:
                    break
            if n_step != 0 and cur_step >= n_step:
                break
            self.data.obs = self.data.obs_next
        self.data.obs = self.data.obs_next

        # generate the statistics
        cur_episode = sum(cur_episode)
        duration = max(time.time() - start_time, 1e-9)
        self.step_speed.add(cur_step / duration)
        self.episode_speed.add(cur_episode / duration)
        self.collect_step += cur_step
        self.collect_episode += cur_episode
        self.collect_time += duration
        if isinstance(n_episode, list):
            n_episode = np.sum(n_episode)
        else:
            n_episode = max(cur_episode, 1)
        reward_sum /= n_episode
        if np.asanyarray(reward_sum).size > 1:  # non-scalar reward_sum
            reward_sum = self._rew_metric(reward_sum)
        return {
            'n/ep': cur_episode,
            'n/st': cur_step,
            'v/st': self.step_speed.get(),
            'v/ep': self.episode_speed.get(),
            'rew': reward_sum,
            'len': length_sum / n_episode,
            **info_sums,
        }

    def sample(self, batch_size: int) -> Batch:
        """Sample a data batch from the internal replay buffer. It will call
        :meth:`~tianshou.policy.BasePolicy.process_fn` before returning
        the final batch data.

        :param int batch_size: ``0`` means it will extract all the data from
            the buffer, otherwise it will extract the data with the given
            batch_size.
        """
        batch_data, indice = self.buffer.sample(batch_size)
        batch_data = self.process_fn(batch_data, self.buffer, indice)
        return batch_data
Exemplo n.º 7
class Collector(object):
    """Collector enables the policy to interact with different types of envs.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param env: a ``gym.Env`` environment or an instance of the
        :class:`~tianshou.env.BaseVectorEnv` class.
    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer`
        class. If set to ``None`` (testing phase), it will not store the data.
    :param function preprocess_fn: a function called before the data has been
        added to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults
        to None.
    :param BaseNoise action_noise: add a noise to continuous action. Normally
        a policy already has a noise param for exploration in training phase,
        so this is recommended to use in test collector for some purpose.
    :param function reward_metric: to be used in multi-agent RL. The reward to
        report is of shape [agent_num], but we need to return a single scalar
        to monitor training. This function specifies what is the desired
        metric, e.g., the reward of agent 1 or the average reward over all
        agents. By default only scalar episode rewards are accepted: a
        non-scalar reward raises an ``AssertionError`` asking for a metric.

    The ``preprocess_fn`` is a function called before the data has been added
    to the buffer with batch format, which receives up to 7 keys as listed in
    :class:`~tianshou.data.Batch`. It will receive with only ``obs`` when the
    collector resets the environment. It returns either a dict or a
    :class:`~tianshou.data.Batch` with the modified keys and values. Examples
    are in "test/base/test_collector.py".

    Here is the example:
    ::

        policy = PGPolicy(...)  # or other policies if you wish
        env = gym.make('CartPole-v0')
        replay_buffer = ReplayBuffer(size=10000)
        # here we set up a collector with a single environment
        collector = Collector(policy, env, buffer=replay_buffer)

        # the collector supports vectorized environments as well
        envs = DummyVectorEnv([lambda: gym.make('CartPole-v0')
                               for _ in range(3)])
        collector = Collector(policy, envs, buffer=replay_buffer)

        # collect 3 episodes
        collector.collect(n_episode=3)
        # collect 1 episode for the first env, 3 for the third env
        collector.collect(n_episode=[1, 0, 3])
        # collect at least 2 steps
        collector.collect(n_step=2)
        # collect episodes with visual rendering (the render argument is the
        #   sleep time between rendering consecutive frames)
        collector.collect(n_episode=1, render=0.03)

    Collected data always consist of full episodes. So if only ``n_step``
    argument is give, the collector may return the data more than the
    ``n_step`` limitation. Same as ``n_episode`` for the multiple environment
    case.

    .. note::

        Please make sure the given environment has a time limitation.
    """

    def __init__(
        self,
        policy: BasePolicy,
        env: Union[gym.Env, BaseVectorEnv],
        buffer: Optional[ReplayBuffer] = None,
        preprocess_fn: Optional[Callable[..., Batch]] = None,
        action_noise: Optional[BaseNoise] = None,
        reward_metric: Optional[Callable[[np.ndarray], float]] = None,
    ) -> None:
        if not isinstance(env, BaseVectorEnv):
            env = DummyVectorEnv([lambda: env])
        self.env = env
        self.env_num = len(env)
        # environments that are available in step()
        # this means all environments in synchronous simulation
        # but only a subset of environments in asynchronous simulation
        self._ready_env_ids = np.arange(self.env_num)
        # self.async is a flag to indicate whether this collector works
        # with asynchronous simulation
        self.is_async = env.is_async
        # need cache buffers before storing in the main buffer
        self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)]
        self.buffer = buffer
        self.policy = policy
        self.preprocess_fn = preprocess_fn
        self.process_fn = policy.process_fn
        self._action_space = env.action_space
        self._action_noise = action_noise
        self._rew_metric = reward_metric or Collector._default_rew_metric
        # avoid creating attribute outside __init__
        self.reset()

    @staticmethod
    def _default_rew_metric(
        x: Union[Number, np.number]
    ) -> Union[Number, np.number]:
        # this internal function is designed for single-agent RL
        # for multi-agent RL, a reward_metric must be provided
        assert np.asanyarray(x).size == 1, (
            "Please specify the reward_metric "
            "since the reward is not a scalar."
        )
        return x

    def reset(self) -> None:
        """Reset all related variables in the collector."""
        # use empty Batch for ``state`` so that ``self.data`` supports slicing
        # convert empty Batch to None when passing data to policy
        self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={},
                          obs_next={}, policy={})
        self.reset_env()
        self.reset_buffer()
        self.reset_stat()
        if self._action_noise is not None:
            self._action_noise.reset()

    def reset_stat(self) -> None:
        """Reset the statistic variables."""
        self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0

    def reset_buffer(self) -> None:
        """Reset the main data buffer."""
        if self.buffer is not None:
            self.buffer.reset()

    def get_env_num(self) -> int:
        """Return the number of environments the collector have."""
        return self.env_num

    def reset_env(self) -> None:
        """Reset all of the environment(s)' states and the cache buffers."""
        self._ready_env_ids = np.arange(self.env_num)
        obs = self.env.reset()
        if self.preprocess_fn:
            obs = self.preprocess_fn(obs=obs).get("obs", obs)
        self.data.obs = obs
        for b in self._cached_buf:
            b.reset()

    def _reset_state(self, id: Union[int, List[int]]) -> None:
        """Reset the hidden state: self.data.state[id]."""
        state = self.data.state  # it is a reference
        if isinstance(state, torch.Tensor):
            state[id].zero_()
        elif isinstance(state, np.ndarray):
            # builtin ``object``: the ``np.object`` alias was removed in
            # NumPy 1.24
            state[id] = None if state.dtype == object else 0
        elif isinstance(state, Batch):
            state.empty_(id)

    def collect(
        self,
        n_step: Optional[int] = None,
        n_episode: Optional[Union[int, List[int]]] = None,
        random: bool = False,
        render: Optional[float] = None,
        no_grad: bool = True,
    ) -> Dict[str, float]:
        """Collect a specified number of step or episode.

        :param int n_step: how many steps you want to collect.
        :param n_episode: how many episodes you want to collect. If it is an
            int, it means to collect at lease ``n_episode`` episodes; if it is
            a list, it means to collect exactly ``n_episode[i]`` episodes in
            the i-th environment
        :param bool random: whether to use random policy for collecting data,
            defaults to False.
        :param float render: the sleep time between rendering consecutive
            frames, defaults to None (no rendering).
        :param bool no_grad: whether to retain gradient in policy.forward,
            defaults to True (no gradient retaining).

        .. note::

            One and only one collection number specification is permitted,
            either ``n_step`` or ``n_episode``.

        :return: A dict including the following keys

            * ``n/ep`` the collected number of episodes.
            * ``n/st`` the collected number of steps.
            * ``v/st`` the speed of steps per second.
            * ``v/ep`` the speed of episode per second.
            * ``rew`` the mean reward over collected episodes.
            * ``rew_std`` the standard deviation of the episode rewards.
            * ``len`` the mean length over collected episodes.
        """
        import warnings  # for the missing-time-limit warning below

        # the message is parenthesized so that the f-string with the actual
        # values is part of it (it used to be a separate, discarded statement)
        assert (n_step is not None and n_episode is None and n_step > 0) or (
            n_step is None and n_episode is not None and np.sum(n_episode) > 0
        ), (
            "Only one of n_step or n_episode is allowed in Collector.collect, "
            f"got n_step = {n_step}, n_episode = {n_episode}."
        )
        start_time = time.time()
        step_count = 0
        # episode of each environment
        episode_count = np.zeros(self.env_num)
        # If n_episode is a list, and some envs have collected the required
        # number of episodes, these envs will be recorded in this list, and
        # they will not be stepped.
        finished_env_ids = []
        rewards = []
        whole_data = Batch()
        if isinstance(n_episode, list):
            assert len(n_episode) == self.get_env_num()
            finished_env_ids = [
                i for i in self._ready_env_ids if n_episode[i] <= 0]
            self._ready_env_ids = np.array(
                [x for x in self._ready_env_ids if x not in finished_env_ids])
        while True:
            if step_count >= 100000 and episode_count.sum() == 0:
                warnings.warn(
                    "There are already many steps in an episode. "
                    "You should add a time limitation to your environment!",
                    Warning)

            is_async = self.is_async or len(finished_env_ids) > 0
            if is_async:
                # self.data are the data for all environments in async
                # simulation or some envs have finished,
                # **only a subset of data are disposed**,
                # so we store the whole data in ``whole_data``, let self.data
                # to be the data available in ready environments, and finally
                # set these back into all the data
                whole_data = self.data
                self.data = self.data[self._ready_env_ids]

            # restore the state and the input data
            last_state = self.data.state
            if isinstance(last_state, Batch) and last_state.is_empty():
                last_state = None
            self.data.update(state=Batch(), obs_next=Batch(), policy=Batch())

            # calculate the next action
            if random:
                spaces = self._action_space
                result = Batch(
                    act=[spaces[i].sample() for i in self._ready_env_ids])
            else:
                if no_grad:
                    with torch.no_grad():  # faster than retain_grad version
                        result = self.policy(self.data, last_state)
                else:
                    result = self.policy(self.data, last_state)

            state = result.get("state", Batch())
            # convert None to Batch(), since None is reserved for 0-init
            if state is None:
                state = Batch()
            self.data.update(state=state, policy=result.get("policy", Batch()))
            # save hidden state to policy._state, in order to save into buffer
            if not (isinstance(state, Batch) and state.is_empty()):
                self.data.policy._state = self.data.state

            self.data.act = to_numpy(result.act)
            if self._action_noise is not None:
                assert isinstance(self.data.act, np.ndarray)
                self.data.act += self._action_noise(self.data.act.shape)

            # step in env
            if not is_async:
                obs_next, rew, done, info = self.env.step(self.data.act)
            else:
                # store computed actions, states, etc
                _batch_set_item(
                    whole_data, self._ready_env_ids, self.data, self.env_num)
                # fetch finished data
                obs_next, rew, done, info = self.env.step(
                    self.data.act, id=self._ready_env_ids)
                self._ready_env_ids = np.array([i["env_id"] for i in info])
                # get the stepped data
                self.data = whole_data[self._ready_env_ids]
            # move data to self.data
            self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)

            if render:
                self.env.render()
                time.sleep(render)

            # add data into the buffer
            if self.preprocess_fn:
                result = self.preprocess_fn(**self.data)  # type: ignore
                self.data.update(result)

            for j, i in enumerate(self._ready_env_ids):
                # j is the index in current ready_env_ids
                # i is the index in all environments
                if self.buffer is None:
                    # users do not want to store data, so we store
                    # small fake data here to make the code clean
                    self._cached_buf[i].add(obs=0, act=0, rew=rew[j], done=0)
                else:
                    self._cached_buf[i].add(**self.data[j])

                if done[j]:
                    if not (isinstance(n_episode, list)
                            and episode_count[i] >= n_episode[i]):
                        episode_count[i] += 1
                        rewards.append(self._rew_metric(
                            np.sum(self._cached_buf[i].rew, axis=0)))
                        step_count += len(self._cached_buf[i])
                        if self.buffer is not None:
                            self.buffer.update(self._cached_buf[i])
                        if isinstance(n_episode, list) and \
                                episode_count[i] >= n_episode[i]:
                            # env i has collected enough data, it has finished
                            finished_env_ids.append(i)
                    self._cached_buf[i].reset()
                    # j, not i: self.data holds only the ready environments
                    self._reset_state(j)
            obs_next = self.data.obs_next
            if sum(done):
                env_ind_local = np.where(done)[0]
                env_ind_global = self._ready_env_ids[env_ind_local]
                obs_reset = self.env.reset(env_ind_global)
                if self.preprocess_fn:
                    obs_reset = self.preprocess_fn(
                        obs=obs_reset).get("obs", obs_reset)
                obs_next[env_ind_local] = obs_reset
            self.data.obs = obs_next
            if is_async:
                # set data back
                whole_data = deepcopy(whole_data)  # avoid reference in ListBuf
                _batch_set_item(
                    whole_data, self._ready_env_ids, self.data, self.env_num)
                # let self.data be the data in all environments again
                self.data = whole_data
            self._ready_env_ids = np.array(
                [x for x in self._ready_env_ids if x not in finished_env_ids])
            if n_step:
                if step_count >= n_step:
                    break
            elif isinstance(n_episode, list):
                if (episode_count >= n_episode).all():
                    break
            # any scalar (including NumPy integers) is a total episode count
            elif episode_count.sum() >= n_episode:
                break

        # finished envs are ready, and can be used for the next collection
        self._ready_env_ids = np.array(
            self._ready_env_ids.tolist() + finished_env_ids)

        # generate the statistics
        episode_count = sum(episode_count)
        duration = max(time.time() - start_time, 1e-9)
        self.collect_step += step_count
        self.collect_episode += episode_count
        self.collect_time += duration
        return {
            "n/ep": episode_count,
            "n/st": step_count,
            "v/st": step_count / duration,
            "v/ep": episode_count / duration,
            "rew": np.mean(rewards),
            "rew_std": np.std(rewards),
            "len": step_count / episode_count,
        }
Exemplo n.º 8
def test_batch():
    """Exercise construction, emptiness, update, indexing, slicing,
    arithmetic and dtype auto-conversion of :class:`Batch`."""
    assert list(Batch()) == []
    assert Batch().is_empty()
    assert not Batch(b={'c': {}}).is_empty()
    assert Batch(b={'c': {}}).is_empty(recurse=True)
    assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
    assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
    assert not Batch(d=1).is_empty()
    assert not Batch(a=np.float64(1.0)).is_empty()
    assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
    assert not Batch(a=[1, 2, 3]).is_empty()
    b = Batch()
    assert b.is_empty()
    b.update(c=[3, 5])
    assert np.allclose(b.c, [3, 5])
    # mimic the behavior of dict.update, where kwargs can overwrite keys
    b.update({'a': 2}, a=3)
    assert b.a == 3
    # non-string keys are rejected
    with pytest.raises(AssertionError):
        Batch({1: 2})
    # a list of arrays/tensors is only accepted when it can be stacked
    # (same shape and same array type)
    with pytest.raises(TypeError):
        Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])
    batch = Batch(a=[torch.ones(3), torch.ones(3)])
    assert torch.allclose(batch.a, torch.ones(2, 3))
    batch = Batch(obs=[0], np=np.zeros([3, 4]))
    assert batch.obs == batch["obs"]
    batch.obs = [1]
    assert batch.obs == [1]
    # concatenating a batch with itself doubles the leading dimension
    batch.cat_(batch)
    assert np.allclose(batch.obs, [1, 1])
    assert batch.np.shape == (6, 4)
    assert np.allclose(batch[0].obs, batch[1].obs)
    batch.obs = np.arange(5)
    for i, b in enumerate(batch.split(1, shuffle=False)):
        if i != 5:
            assert b.obs == batch[i].obs
        else:
            with pytest.raises(AttributeError):
                batch[i].obs
            with pytest.raises(AttributeError):
                b.obs
    batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])}
    batch_item = Batch({'a': [batch_dict]})[0]
    assert isinstance(batch_item.a.b, np.ndarray)
    assert batch_item.a.b == batch_dict['b']
    assert isinstance(batch_item.a.c, float)
    assert batch_item.a.c == batch_dict['c']
    assert isinstance(batch_item.a.d, torch.Tensor)
    assert batch_item.a.d == batch_dict['d']
    batch2 = Batch(a=[{
        'b': np.float64(1.0),
        'c': np.zeros(1),
        'd': Batch(e=np.array(3.0))}])
    assert len(batch2) == 1
    assert Batch().shape == []
    assert Batch(a=1).shape == []
    assert batch2.shape[0] == 1
    # out-of-range indices on a length-1 batch
    with pytest.raises(IndexError):
        batch2[-2]
    with pytest.raises(IndexError):
        batch2[1]
    # a single item is a scalar batch: it cannot be indexed and has no len()
    assert batch2[0].shape == []
    with pytest.raises(IndexError):
        batch2[0][0]
    with pytest.raises(TypeError):
        len(batch2[0])
    assert isinstance(batch2[0].a.c, np.ndarray)
    assert isinstance(batch2[0].a.b, np.float64)
    assert isinstance(batch2[0].a.d.e, np.float64)
    # rebuilding from an iteration over the batch must round-trip
    batch2_from_list = Batch(list(batch2))
    batch2_from_comp = Batch([e for e in batch2])
    assert batch2_from_list.a.b == batch2.a.b
    assert batch2_from_list.a.c == batch2.a.c
    assert batch2_from_list.a.d.e == batch2.a.d.e
    assert batch2_from_comp.a.b == batch2.a.b
    assert batch2_from_comp.a.c == batch2.a.c
    assert batch2_from_comp.a.d.e == batch2.a.d.e
    for batch_slice in [batch2[slice(0, 1)], batch2[:1], batch2[0:]]:
        assert batch_slice.a.b == batch2.a.b
        assert batch_slice.a.c == batch2.a.c
        assert batch_slice.a.d.e == batch2.a.d.e
    batch2_sum = (batch2 + 1.0) * 2
    assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
    assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
    assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
    batch3 = Batch(a={
        'c': np.zeros(1),
        'd': Batch(e=np.array([0.0]), f=np.array([3.0]))})
    batch3.a.d[0] = {'e': 4.0}
    assert batch3.a.d.e[0] == 4.0
    batch3.a.d[0] = Batch(f=5.0)
    assert batch3.a.d.f[0] == 5.0
    # assigning an item with a key the batch does not have is rejected
    with pytest.raises(KeyError):
        batch3.a.d[0] = Batch(f=5.0, g=0.0)
    # auto convert
    batch4 = Batch(a=np.array(['a', 'b']))
    # ``np.object`` was removed in NumPy 1.24; the builtin ``object`` is the
    # equivalent dtype spelling
    assert batch4.a.dtype == object  # auto convert to object dtype
    batch4.update(a=np.array(['c', 'd']))
    assert list(batch4.a) == ['c', 'd']
    assert batch4.a.dtype == object  # auto convert to object dtype
    batch5 = Batch(a=np.array([{'index': 0}]))
    assert isinstance(batch5.a, Batch)
    assert np.allclose(batch5.a.index, [0])
    batch5.b = np.array([{'index': 1}])
    assert isinstance(batch5.b, Batch)
    assert np.allclose(batch5.b.index, [1])

    # None is a valid object and can be stored in Batch
    a = Batch.stack([Batch(a=None), Batch(b=None)])
    assert a.a[0] is None and a.a[1] is None
    assert a.b[0] is None and a.b[1] is None
Exemplo n.º 9
class Collector(object):
    """The :class:`~tianshou.data.Collector` enables the policy to interact
    with different types of environments conveniently.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param env: a ``gym.Env`` environment or an instance of the
        :class:`~tianshou.env.BaseVectorEnv` class.
    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer`
        class. If set to ``None`` (testing phase), it will not store the data.
    :param function preprocess_fn: a function called before the data has been
        added to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults
        to ``None``.
    :param BaseNoise action_noise: add a noise to continuous action. Normally
        a policy already has a noise param for exploration in training phase,
        so this is recommended to use in test collector for some purpose.
    :param function reward_metric: to be used in multi-agent RL. The reward to
        report is of shape [agent_num], but we need to return a single scalar
        to monitor training. This function specifies what is the desired
        metric, e.g., the reward of agent 1 or the average reward over all
        agents. By default the reward must already be a scalar; a non-scalar
        reward without a ``reward_metric`` raises an ``AssertionError``.

    The ``preprocess_fn`` is a function called before the data has been added
    to the buffer with batch format, which receives up to 7 keys as listed in
    :class:`~tianshou.data.Batch`. It will receive only ``obs`` when the
    collector resets the environment. It returns either a dict or a
    :class:`~tianshou.data.Batch` with the modified keys and values. Examples
    are in "test/base/test_collector.py".

    Example:
    ::

        policy = PGPolicy(...)  # or other policies if you wish
        env = gym.make('CartPole-v0')
        replay_buffer = ReplayBuffer(size=10000)
        # here we set up a collector with a single environment
        collector = Collector(policy, env, buffer=replay_buffer)

        # the collector supports vectorized environments as well
        envs = DummyVectorEnv([lambda: gym.make('CartPole-v0')
                               for _ in range(3)])
        collector = Collector(policy, envs, buffer=replay_buffer)

        # collect 3 episodes
        collector.collect(n_episode=3)
        # collect 1 episode for the first env, 3 for the third env
        collector.collect(n_episode=[1, 0, 3])
        # collect at least 2 steps
        collector.collect(n_step=2)
        # collect episodes with visual rendering (the render argument is the
        #   sleep time between rendering consecutive frames)
        collector.collect(n_episode=1, render=0.03)

    Collected data always consist of full episodes. So if only ``n_step``
    argument is given, the collector may return the data more than the
    ``n_step`` limitation. Same as ``n_episode`` for the multiple environment
    case.

    .. note::

        Please make sure the given environment has a time limitation.
    """

    def __init__(
        self,
        policy: BasePolicy,
        env: Union[gym.Env, BaseVectorEnv],
        buffer: Optional[ReplayBuffer] = None,
        preprocess_fn: Optional[Callable[..., Union[dict, Batch]]] = None,
        action_noise: Optional[BaseNoise] = None,
        reward_metric: Optional[Callable[[np.ndarray], float]] = None,
    ) -> None:
        if not isinstance(env, BaseVectorEnv):
            env = DummyVectorEnv([lambda: env])
        self.env = env
        self.env_num = len(env)
        # environments that are available in step()
        # this means all environments in synchronous simulation
        # but only a subset of environments in asynchronous simulation
        self._ready_env_ids = np.arange(self.env_num)
        # self.is_async is a flag to indicate whether this collector works
        # with asynchronous simulation
        self.is_async = env.is_async
        # need cache buffers before storing in the main buffer
        self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)]
        self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
        self.buffer = buffer
        self.policy = policy
        self.preprocess_fn = preprocess_fn
        self.process_fn = policy.process_fn
        self._action_space = env.action_space
        self._action_noise = action_noise
        self._rew_metric = reward_metric or Collector._default_rew_metric
        # avoid creating attribute outside __init__
        self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={},
                          obs_next={}, policy={})
        self.reset()

    @staticmethod
    def _default_rew_metric(x):
        # this internal function is designed for single-agent RL
        # for multi-agent RL, a reward_metric must be provided
        assert np.asanyarray(x).size == 1, \
            'Please specify the reward_metric ' \
            'since the reward is not a scalar.'
        return x

    def reset(self) -> None:
        """Reset all related variables in the collector."""
        # use empty Batch for ``state`` so that ``self.data`` supports slicing
        # convert empty Batch to None when passing data to policy
        self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={},
                          obs_next={}, policy={})
        self.reset_env()
        self.reset_buffer()
        self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
        if self._action_noise is not None:
            self._action_noise.reset()

    def reset_buffer(self) -> None:
        """Reset the main data buffer."""
        if self.buffer is not None:
            self.buffer.reset()

    def get_env_num(self) -> int:
        """Return the number of environments the collector have."""
        return self.env_num

    def reset_env(self) -> None:
        """Reset all of the environment(s)' states and reset all of the cache
        buffers (if need).
        """
        self._ready_env_ids = np.arange(self.env_num)
        obs = self.env.reset()
        if self.preprocess_fn:
            obs = self.preprocess_fn(obs=obs).get('obs', obs)
        self.data.obs = obs
        for b in self._cached_buf:
            b.reset()

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
        """Reset all the seed(s) of the given environment(s)."""
        return self.env.seed(seed)

    def render(self, **kwargs) -> None:
        """Render all the environment(s)."""
        return self.env.render(**kwargs)

    def _reset_state(self, id: Union[int, List[int]]) -> None:
        """Reset the hidden state: self.data.state[id]."""
        state = self.data.state  # it is a reference
        if isinstance(state, torch.Tensor):
            state[id].zero_()
        elif isinstance(state, np.ndarray):
            # ``np.object`` was removed in NumPy 1.24; ``object`` is equivalent
            state[id] = None if state.dtype == object else 0
        elif isinstance(state, Batch):
            state.empty_(id)

    def collect(
        self,
        n_step: Optional[int] = None,
        n_episode: Optional[Union[int, List[int]]] = None,
        random: bool = False,
        render: Optional[float] = None,
    ) -> Dict[str, float]:
        """Collect a specified number of step or episode.

        :param int n_step: how many steps you want to collect.
        :param n_episode: how many episodes you want to collect. If it is an
            int, it means to collect at least ``n_episode`` episodes; if it is
            a list, it means to collect exactly ``n_episode[i]`` episodes in
            the i-th environment
        :param bool random: whether to use random policy for collecting data,
            defaults to ``False``.
        :param float render: the sleep time between rendering consecutive
            frames, defaults to ``None`` (no rendering).

        .. note::

            One and only one collection number specification is permitted,
            either ``n_step`` or ``n_episode``.

        :return: A dict including the following keys

            * ``n/ep`` the collected number of episodes.
            * ``n/st`` the collected number of steps.
            * ``v/st`` the speed of steps per second.
            * ``v/ep`` the speed of episode per second.
            * ``rew`` the mean reward over collected episodes (reduced with
              ``reward_metric`` when it is not a scalar).
            * ``len`` the mean length over collected episodes.
        """
        assert (n_step and not n_episode) or (not n_step and n_episode), \
            "One and only one collection number specification is permitted!"
        start_time = time.time()
        step_count = 0
        # episode of each environment
        episode_count = np.zeros(self.env_num)
        reward_total = 0.0
        whole_data = Batch()
        while True:
            if step_count >= 100000 and episode_count.sum() == 0:
                warnings.warn(
                    'There are already many steps in an episode. '
                    'You should add a time limitation to your environment!',
                    Warning)

            if self.is_async:
                # self.data are the data for all environments
                # in async simulation, only a subset of data are disposed
                # so we store the whole data in ``whole_data``, let self.data
                # to be all the data available in ready environments, and
                # finally set these back into all the data
                whole_data = self.data
                self.data = self.data[self._ready_env_ids]

            # restore the state and the input data
            last_state = self.data.state
            if isinstance(last_state, Batch) and last_state.is_empty():
                last_state = None
            self.data.update(state=Batch(), obs_next=Batch(), policy=Batch())

            # calculate the next action
            if random:
                spaces = self._action_space
                result = Batch(
                    act=[spaces[i].sample() for i in self._ready_env_ids])
            else:
                with torch.no_grad():
                    result = self.policy(self.data, last_state)

            state = result.get('state', Batch())
            # convert None to Batch(), since None is reserved for 0-init
            if state is None:
                state = Batch()
            self.data.update(state=state, policy=result.get('policy', Batch()))
            # save hidden state to policy._state, in order to save into buffer
            if not (isinstance(self.data.state, Batch)
                    and self.data.state.is_empty()):
                self.data.policy._state = self.data.state

            self.data.act = to_numpy(result.act)
            if self._action_noise is not None:
                self.data.act += self._action_noise(self.data.act.shape)

            # step in env
            if not self.is_async:
                obs_next, rew, done, info = self.env.step(self.data.act)
            else:
                # store computed actions, states, etc
                _batch_set_item(whole_data, self._ready_env_ids, self.data,
                                self.env_num)
                # fetch finished data
                obs_next, rew, done, info = self.env.step(
                    action=self.data.act, id=self._ready_env_ids)
                self._ready_env_ids = np.array([i['env_id'] for i in info])
                # get the stepped data
                self.data = whole_data[self._ready_env_ids]
            # move data to self.data
            self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)

            if render:
                self.render()
                time.sleep(render)

            # add data into the buffer
            if self.preprocess_fn:
                result = self.preprocess_fn(**self.data)
                self.data.update(result)
            for j, i in enumerate(self._ready_env_ids):
                # j is the index in current ready_env_ids
                # i is the index in all environments
                self._cached_buf[i].add(**self.data[j])
                if self.data.done[j]:
                    # with a per-env episode list, episodes beyond the quota
                    # are discarded instead of being counted/stored
                    if n_step or np.isscalar(n_episode) or \
                            episode_count[i] < n_episode[i]:
                        episode_count[i] += 1
                        reward_total += np.sum(self._cached_buf[i].rew, axis=0)
                        step_count += len(self._cached_buf[i])
                        if self.buffer is not None:
                            self.buffer.update(self._cached_buf[i])
                    self._cached_buf[i].reset()
                    self._reset_state(j)
            obs_next = self.data.obs_next
            if sum(self.data.done):
                env_ind_local = np.where(self.data.done)[0]
                env_ind_global = self._ready_env_ids[env_ind_local]
                obs_reset = self.env.reset(env_ind_global)
                if self.preprocess_fn:
                    obs_reset = self.preprocess_fn(
                        obs=obs_reset).get('obs', obs_reset)
                obs_next[env_ind_local] = obs_reset
            self.data.obs = obs_next
            if self.is_async:
                # set data back
                _batch_set_item(whole_data, self._ready_env_ids, self.data,
                                self.env_num)
                # let self.data be the data in all environments again
                self.data = whole_data
            if n_step:
                if step_count >= n_step:
                    break
            else:
                if isinstance(n_episode, int) and \
                        episode_count.sum() >= n_episode:
                    break
                if isinstance(n_episode, list) and \
                        (episode_count >= n_episode).all():
                    break

        # generate the statistics
        episode_count = sum(episode_count)
        duration = max(time.time() - start_time, 1e-9)
        self.collect_step += step_count
        self.collect_episode += episode_count
        self.collect_time += duration
        # average reward across the number of episodes
        reward_avg = reward_total / episode_count
        if np.asanyarray(reward_avg).size > 1:  # non-scalar reward_avg
            reward_avg = self._rew_metric(reward_avg)
        return {
            'n/ep': episode_count,
            'n/st': step_count,
            'v/st': step_count / duration,
            'v/ep': episode_count / duration,
            'rew': reward_avg,
            'len': step_count / episode_count,
        }

    def sample(self, batch_size: int) -> Batch:
        """Sample a data batch from the internal replay buffer. It will call
        :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the
        final batch data.

        :param int batch_size: ``0`` means it will extract all the data from
            the buffer, otherwise it will extract the data with the given
            batch_size.
        """
        warnings.warn(
            'Collector.sample is deprecated and will cause error if you use '
            'prioritized experience replay! Collector.sample will be removed '
            'upon version 0.3. Use policy.update instead!', Warning)
        batch_data, indice = self.buffer.sample(batch_size)
        batch_data = self.process_fn(batch_data, self.buffer, indice)
        return batch_data

    def close(self) -> None:
        """Close the environment(s). Deprecated."""
        warnings.warn(
            'Collector.close is deprecated and will be removed upon version '
            '0.3.', Warning)
        self.env.close()
Exemplo n.º 10
    def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                   indice: np.ndarray) -> Batch:
        """Run the parent ``process_fn`` and attach a per-sample ``weight``.

        ``rel_step`` is ``batch.step`` divided by its maximum (all zeros when
        every step is 0), flipped to ``1 - rel_step`` when ``self.bk_step``
        is set. ``self.reweigh_type`` selects the weighting scheme:

        * ``"hard"``: ``self.tper_weight`` for samples whose step is below the
          median (above it when ``self.bk_step``), ``2 - self.tper_weight``
          otherwise.
        * ``"linear"``: ``self._calc_linear_weight`` with the fixed bounds
          ``self.l`` and ``self.h``.
        * ``"adaptive_linear"``: as ``"linear"``, but the low/high bounds are
          interpolated on ``self._iter`` between ``self.t_s`` and
          ``self.t_e`` and clipped to ``[low_l, low_h]`` / ``[high_l,
          high_h]``.
        * ``"done_cnt_linear"``: as ``"adaptive_linear"``, but interpolated
          per sample on ``1 - done_cnt / max(done_cnt)``.
        * ``"oracle"``: ``exp(-|Q(s, a) - Q*(s, a)|)`` rescaled to mean 1,
          where ``Q*`` is ``rew + (1 - done) * V(next agent position)``.

        :param batch: sampled transitions; must contain ``step`` (and
            ``done_cnt`` for ``"done_cnt_linear"``).
        :param buffer: the buffer the batch was sampled from.
        :param indice: indices of the sampled transitions in ``buffer``.
        :return: ``batch`` updated with a ``weight`` key.
        :raises ValueError: if ``self.reweigh_type`` is not a known scheme.
        """
        batch = super().process_fn(batch, buffer, indice)
        step = batch.step
        rel_step = step / np.max(step) if step.any() else np.zeros_like(step)
        if self.bk_step:
            # convert bk step to forward
            rel_step = 1 - rel_step
        if self.reweigh_type == "hard":
            med = np.median(step)
            cond = step > med if self.bk_step else step < med
            weight = np.where(cond, self.tper_weight, 2 - self.tper_weight)
        elif self.reweigh_type == "linear":
            weight = self._calc_linear_weight(rel_step, self.l, self.h, self.k,
                                              self.b)
        elif self.reweigh_type == 'adaptive_linear':
            # bounds move linearly with the training iteration between
            # t_s and t_e, then stay clipped at their end values
            cur_low = np.clip(
                self.low_l + (self.low_h - self.low_l) /
                (self.t_e - self.t_s) * (self._iter - self.t_s), self.low_l,
                self.low_h)
            cur_high = np.clip(
                self.high_h + (self.high_l - self.high_h) /
                (self.t_e - self.t_s) * (self._iter - self.t_s), self.high_l,
                self.high_h)
            weight = self._calc_linear_weight(rel_step, cur_low, cur_high,
                                              self.k, self.b)
        elif self.reweigh_type == 'done_cnt_linear':
            done_cnt = batch.done_cnt
            # same zero guard as rel_step: avoid 0 / 0 -> NaN weights
            rel_done_cnt = (done_cnt / np.max(done_cnt) if done_cnt.any()
                            else np.zeros_like(done_cnt, dtype=float))
            # The trajectory is newer with larger done counts, which can be
            # understood as fewer learning steps
            pseudo_step = 1 - rel_done_cnt
            cur_low = np.clip(
                self.low_l + (self.low_h - self.low_l) * pseudo_step,
                self.low_l, self.low_h)
            cur_high = np.clip(
                self.high_h + (self.high_l - self.high_h) * pseudo_step,
                self.high_l, self.high_h)
            weight = self._calc_linear_weight(rel_step, cur_low, cur_high,
                                              self.k, self.b)
        elif self.reweigh_type == 'oracle':
            agent_pos = batch.info["agent_pos"]
            reward = batch.rew
            done = batch.done
            action = batch.act

            next_agent_pos = self._get_next_agent_pos(agent_pos, action)
            next_V = self._get_oracle_V(next_agent_pos)
            Qstar = reward + (1 - done) * next_V
            with torch.no_grad():
                Qs = self.forward(batch).logits.detach().cpu().numpy()
            # Q-value of the action actually taken in each sample
            Qk = np.asarray([Q[a] for Q, a in zip(Qs, action)])
            weight = np.exp(-np.abs(Qk - Qstar))
            assert weight.shape[0] == rel_step.shape[0]
            # rescale so the weights average to 1 over the batch
            weight = weight / np.sum(weight) * rel_step.shape[0]
        else:
            raise ValueError(
                f"Unknown reweigh_type: {self.reweigh_type!r}")
        batch.update({"weight": weight})
        return batch
Exemplo n.º 11
class adversarial_training_collector(object):
    """Collector that defends an existing policy with adversarial training.
    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
    :param env: a ``gym.Env`` environment or an instance of the
        :class:`~tianshou.env.BaseVectorEnv` class.
    :param obs_adv_atk: an instance of the :class:`~advertorch.attacks.base.Attack`
        class implementing an image adversarial attack.
    :param atk_frequency: float, how frequently attacking env observations
    :param test: bool, if True adversarial actions replace original actions
    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer`
        class. If set to ``None`` (testing phase), it will not store the data.
    :param function preprocess_fn: a function called before the data has been
        added to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults
        to None.
    :param function reward_metric: to be used in multi-agent RL. The reward to
        report is of shape [agent_num], but we need to return a single scalar
        to monitor training. This function specifies what is the desired
        metric, e.g., the reward of agent 1 or the average reward over all
        agents. By default, the behavior is to select the reward of agent 1.
    :param atk_frequency: float, how frequently attacking env observations.
    Note: parallel or async envs are currently not supported
    def __init__(
            policy: BasePolicy,
            env: Union[gym.Env, BaseVectorEnv],
            obs_adv_atk: Attack,
            buffer: Optional[ReplayBuffer] = None,
            preprocess_fn: Optional[Callable[..., Batch]] = None,
            reward_metric: Optional[Callable[[np.ndarray], float]] = None,
            atk_frequency: float = 0.5,
            test: bool = False,
            device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    ) -> None:
        if not isinstance(env, BaseVectorEnv):
            env = DummyVectorEnv([lambda: env])
        self.env = env
        self.env_num = len(env)
        self.device = device
        self.obs_adv_atk = obs_adv_atk
        self.obs_adv_atk.targeted = False
        self.atk_frequency = atk_frequency
        self.test = test
        # environments that are available in step()
        # this means all environments in synchronous simulation
        # but only a subset of environments in asynchronous simulation
        self._ready_env_ids = np.arange(self.env_num)
        # need cache buffers before storing in the main buffer
        self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)]
        self.buffer = buffer
        self.policy = policy
        self.preprocess_fn = preprocess_fn
        self.process_fn = policy.process_fn
        self._action_space = env.action_space
        self._rew_metric = reward_metric or adversarial_training_collector._default_rew_metric
        # avoid creating attribute outside __init__

    def _default_rew_metric(
            x: Union[Number, np.number]) -> Union[Number, np.number]:
        # this internal function is designed for single-agent RL
        # for multi-agent RL, a reward_metric must be provided
        assert np.asanyarray(x).size == 1, (
            "Please specify the reward_metric "
            "since the reward is not a scalar.")
        return x

    def reset(self) -> None:
        """Reset all related variables in the collector."""
        # use empty Batch for ``state`` so that ``self.data`` supports slicing
        # convert empty Batch to None when passing data to policy
        self.data = Batch(state={},

    def reset_stat(self) -> None:
        """Reset the statistic variables."""
        self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0

    def reset_buffer(self) -> None:
        """Reset the main data buffer."""
        if self.buffer is not None:

    def get_env_num(self) -> int:
        """Return the number of environments the collector have."""
        return self.env_num

    def reset_env(self) -> None:
        """Reset all of the environment(s)' states and the cache buffers."""
        self._ready_env_ids = np.arange(self.env_num)
        obs = self.env.reset()
        if self.preprocess_fn:
            obs = self.preprocess_fn(obs=obs).get("obs", obs)
        self.data.obs = obs
        for b in self._cached_buf:

    def _reset_state(self, id: Union[int, List[int]]) -> None:
        """Reset the hidden state: self.data.state[id]."""
        state = self.data.state  # it is a reference
        if isinstance(state, torch.Tensor):
        elif isinstance(state, np.ndarray):
            state[id] = None if state.dtype == np.object else 0
        elif isinstance(state, Batch):

    def collect(
        n_step: Optional[int] = None,
        n_episode: Optional[Union[int, List[int]]] = None,
        random: bool = False,
        render: Optional[float] = None,
        no_grad: bool = True,
    ) -> Dict[str, float]:
        """Collect a specified number of step or episode.
        :param int n_step: how many steps you want to collect.
        :param n_episode: how many episodes you want to collect. If it is an
            int, it means to collect at lease ``n_episode`` episodes; if it is
            a list, it means to collect exactly ``n_episode[i]`` episodes in
            the i-th environment
        :param bool random: whether to use random policy for collecting data,
            defaults to False.
        :param float render: the sleep time between rendering consecutive
            frames, defaults to None (no rendering).
        :param bool no_grad: whether to retain gradient in policy.forward,
            defaults to True (no gradient retaining).
        .. note::
            One and only one collection number specification is permitted,
            either ``n_step`` or ``n_episode``.
        :return: A dict including the following keys
            * ``n/ep`` the collected number of episodes.
            * ``n/st`` the collected number of steps.
            * ``v/st`` the speed of steps per second.
            * ``v/ep`` the speed of episode per second.
            * ``rew`` the mean reward over collected episodes.
            * ``len`` the mean length over collected episodes.
        assert (n_step is not None and n_episode is None and n_step > 0) or (
            n_step is None and n_episode is not None and np.sum(n_episode) > 0
        ), "Only one of n_step or n_episode is allowed in Collector.collect, "
        f"got n_step = {n_step}, n_episode = {n_episode}."
        start_time = time.time()
        step_count = 0
        succ_attacks = 0
        n_attacks = 0
        # episode of each environment
        episode_count = np.zeros(self.env_num)
        # If n_episode is a list, and some envs have collected the required
        # number of episodes, these envs will be recorded in this list, and
        # they will not be stepped.
        finished_env_ids = []
        rewards = []
        if isinstance(n_episode, list):
            assert len(n_episode) == self.get_env_num()
            finished_env_ids = [
                i for i in self._ready_env_ids if n_episode[i] <= 0
            self._ready_env_ids = np.array(
                [x for x in self._ready_env_ids if x not in finished_env_ids])
        while True:
            if step_count >= 100000 and episode_count.sum() == 0:
                    "There are already many steps in an episode. "
                    "You should add a time limitation to your environment!",

            # restore the state and the input data
            last_state = self.data.state
            if isinstance(last_state, Batch) and last_state.is_empty():
                last_state = None
            self.data.update(state=Batch(), obs_next=Batch(), policy=Batch())

            # calculate the next action
            if random:
                spaces = self._action_space
                result = Batch(
                    act=[spaces[i].sample() for i in self._ready_env_ids])
                if no_grad:
                    with torch.no_grad():  # faster than retain_grad version
                        result = self.policy(self.data, last_state)
                    result = self.policy(self.data, last_state)

            state = result.get("state", Batch())
            # convert None to Batch(), since None is reserved for 0-init
            if state is None:
                state = Batch()
            self.data.update(state=state, policy=result.get("policy", Batch()))
            # save hidden state to policy._state, in order to save into buffer
            if not (isinstance(state, Batch) and state.is_empty()):
                self.data.policy._state = self.data.state

            self.data.act = to_numpy(result.act)

            x = rd.uniform(0, 1)
            if x < self.atk_frequency:
                ori_act = self.data.act
                adv_act, adv_obs = self.obs_attacks(self.data, ori_act)
                for j, i in enumerate(self._ready_env_ids):
                    if adv_act[i] != ori_act[i]:
                        succ_attacks += 1
                n_attacks += self.env_num
                )  # so that the adv obs will be inserted in the buffer
                if self.test:
                    self.data.act = adv_act

            # step in env
            obs_next, rew, done, info = self.env.step(self.data.act)

            # move data to self.data
            self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)

            if render:

            # add data into the buffer
            if self.preprocess_fn:
                result = self.preprocess_fn(**self.data)  # type: ignore

            for j, i in enumerate(self._ready_env_ids):
                # j is the index in current ready_env_ids
                # i is the index in all environments
                if self.buffer is None:
                    # users do not want to store data, so we store
                    # small fake data here to make the code clean
                    self._cached_buf[i].add(obs=0, act=0, rew=rew[j], done=0)

                if done[j]:
                    if not (isinstance(n_episode, list)
                            and episode_count[i] >= n_episode[i]):
                        episode_count[i] += 1
                                np.sum(self._cached_buf[i].rew, axis=0)))
                        step_count += len(self._cached_buf[i])
                        if self.buffer is not None:
                        if isinstance(n_episode, list) and \
                                episode_count[i] >= n_episode[i]:
                            # env i has collected enough data, it has finished
            obs_next = self.data.obs_next
            if sum(done):
                env_ind_local = np.where(done)[0]
                env_ind_global = self._ready_env_ids[env_ind_local]
                obs_reset = self.env.reset(env_ind_global)
                if self.preprocess_fn:
                    obs_reset = self.preprocess_fn(obs=obs_reset).get(
                        "obs", obs_reset)
                obs_next[env_ind_local] = obs_reset
            self.data.obs = obs_next
            self._ready_env_ids = np.array(
                [x for x in self._ready_env_ids if x not in finished_env_ids])
            if n_step:
                if step_count >= n_step:
                if isinstance(n_episode, int) and \
                        episode_count.sum() >= n_episode:
                if isinstance(n_episode, list) and \
                        (episode_count >= n_episode).all():

        # finished envs are ready, and can be used for the next collection
        self._ready_env_ids = np.array(self._ready_env_ids.tolist() +

        # generate the statistics
        episode_count = sum(episode_count)
        duration = max(time.time() - start_time, 1e-9)
        self.collect_step += step_count
        self.collect_episode += episode_count
        self.collect_time += duration
        return {
            "n/ep": episode_count,
            "n/st": step_count,
            "v/st": step_count / duration,
            "v/ep": episode_count / duration,
            "rew": np.mean(rewards),
            "rew_std": np.std(rewards),
            "len": step_count / episode_count,
            'succ_atks(%)': succ_attacks / n_attacks if n_attacks > 0 else 0,

    def obs_attacks(self, data, target_action: List[int]):
        """Craft an adversarial observation and return the policy's response to it.

        Performs an image adversarial attack on the observation stored in
        ``data.obs`` with respect to the action ``target_action``, using the
        attack object stored in ``self.obs_adv_atk`` (its ``perturb`` method).
        The policy is then re-run on the perturbed observation.

        :param data: batch-like object with an ``obs`` attribute. It is
            deep-copied first, so the caller's object is not modified.
        :param target_action: actions the attack is computed against. In the
            collector loop this is the policy's current action array
            (``self.data.act``).
        :return: a tuple ``(adv_act, adv_obs)``. ``adv_act`` is the policy's
            action on the perturbed observation, converted with ``to_numpy``.
            ``adv_obs`` is the perturbed observation as a numpy array.
        """
        # Work on a private copy so the perturbed obs never leaks into the
        # caller's batch (the caller decides what to store).
        data = deepcopy(data)
        obs = torch.FloatTensor(data.obs).to(
            self.device)  # convert observation to tensor
        act = torch.tensor(target_action).to(
            self.device)  # convert action to tensor
        # Deliberately outside torch.no_grad(), so the attack can use autograd
        # if it needs to (e.g. a gradient-based method).
        adv_obs = self.obs_adv_atk.perturb(
            obs, act)  # create adversarial observation
        with torch.no_grad():
            adv_obs = adv_obs.cpu().detach().numpy()
            data.obs = adv_obs
            # Query the policy on the perturbed observation; hidden state is
            # not carried over (last_state=None).
            result = self.policy(data, last_state=None)
        return to_numpy(result.act), adv_obs