Example #1
def test_replaybuffer(size=10, bufsize=20):
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize)
    buf.update(buf)
    assert str(buf) == buf.__class__.__name__ + '()'
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        buf.add(obs, [a], rew, done, obs_next, info)
        obs = obs_next
        assert len(buf) == min(bufsize, i + 1)
    with pytest.raises(ValueError):
        buf._add_to_buffer('rew', np.array([1, 2, 3]))
    assert buf.act.dtype == np.object
    assert isinstance(buf.act[0], list)
    data, indice = buf.sample(bufsize * 2)
    assert (indice < len(buf)).all()
    assert (data.obs < size).all()
    assert (0 <= data.done).all() and (data.done <= 1).all()
    b = ReplayBuffer(size=10)
    b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}})
    assert b.obs[0] == 1
    assert b.done[0] == 'str'
    assert np.all(b.obs[1:] == 0)
    assert np.all(b.done[1:] == np.array(None))
    assert b.info.a[0] == 3 and b.info.a.dtype == np.integer
    assert np.all(b.info.a[1:] == 0)
    assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact
    assert np.all(b.info.b.c[1:] == 0.0)
    with pytest.raises(IndexError):
        b[22]
    b = ListReplayBuffer()
    with pytest.raises(NotImplementedError):
        b.sample(0)
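
A note on the assertions above: ReplayBuffer is a ring buffer, so the write cursor wraps around modulo the maximum size while len() saturates at that size. Below is a minimal standalone sketch of that indexing; TinyRingBuffer and its fields are assumptions for illustration, not tianshou code.

class TinyRingBuffer:
    """Toy ring buffer: len() grows to maxsize, then new entries overwrite the oldest."""

    def __init__(self, maxsize):
        self.maxsize, self._index, self._size = maxsize, 0, 0
        self.rew = [0.0] * maxsize

    def add(self, rew):
        self.rew[self._index] = rew
        self._index = (self._index + 1) % self.maxsize
        self._size = min(self._size + 1, self.maxsize)

    def __len__(self):
        return self._size

buf = TinyRingBuffer(3)
for r in [1, 2, 3, 4]:
    buf.add(r)
print(len(buf), buf.rew)  # 3 [4, 2, 3] -- the oldest entry is overwritten first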
Example #2
def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3):
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize, stack_num=stack_num)
    buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
    buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True)
    obs = env.reset(1)
    for _ in range(16):
        obs_next, rew, done, info = env.step(1)
        buf.add(Batch(obs=obs, act=1, rew=rew, done=done, info=info))
        buf2.add(Batch(obs=obs, act=1, rew=rew, done=done, info=info))
        buf3.add(
            Batch(obs=[obs, obs, obs],
                  act=1,
                  rew=rew,
                  done=done,
                  obs_next=[obs, obs],
                  info=info))
        obs = obs_next
        if done:
            obs = env.reset(1)
    indices = np.arange(len(buf))
    assert np.allclose(
        buf.get(indices, 'obs')[..., 0],
        [[1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4], [1, 1, 1, 1], [1, 1, 1, 2],
         [1, 1, 2, 3], [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]])
    assert np.allclose(buf.get(indices, 'obs'), buf3.get(indices, 'obs'))
    assert np.allclose(buf.get(indices, 'obs'), buf3.get(indices, 'obs_next'))
    _, indices = buf2.sample(0)
    assert indices.tolist() == [2, 6]
    _, indices = buf2.sample(1)
    assert indices[0] in [2, 6]
    batch, indices = buf2.sample(-1)  # neg bsz -> no data
    assert indices.tolist() == [] and len(batch) == 0
    with pytest.raises(IndexError):
        buf[bufsize * 2]
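
The expected rows such as [1, 1, 1, 2] and [1, 2, 3, 4] above come from frame stacking with stack_num=4: for each index the buffer returns the last stack_num observations of the same episode, repeating the episode's first observation near the episode start. A rough standalone sketch of that rule; stacked_obs and its arguments are assumptions for illustration, not the buffer's real API.

import numpy as np

def stacked_obs(obs, episode_start, t, stack_num=4):
    """Collect the last stack_num observations ending at t, without crossing
    the episode boundary; the first frame is repeated near the start."""
    frames, idx = [], t
    for _ in range(stack_num):
        frames.append(obs[idx])
        if idx > episode_start:
            idx -= 1
    return np.array(frames[::-1])

obs = np.array([1, 2, 3, 4, 5])
print(stacked_obs(obs, episode_start=0, t=1))  # -> [1 1 1 2]
print(stacked_obs(obs, episode_start=0, t=3))  # -> [1 2 3 4]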
Example #3
def test_stack(size=5, bufsize=9, stack_num=4):
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize, stack_num=stack_num)
    buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
    buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True)
    obs = env.reset(1)
    for i in range(16):
        obs_next, rew, done, info = env.step(1)
        buf.add(obs, 1, rew, done, None, info)
        buf2.add(obs, 1, rew, done, None, info)
        buf3.add([None, None, obs], 1, rew, done, [None, obs], info)
        obs = obs_next
        if done:
            obs = env.reset(1)
    indice = np.arange(len(buf))
    assert np.allclose(buf.get(indice, 'obs')[..., 0], [
        [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
        [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
        [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]])
    assert np.allclose(buf.get(indice, 'obs'), buf3.get(indice, 'obs'))
    assert np.allclose(buf.get(indice, 'obs'), buf3.get(indice, 'obs_next'))
    _, indice = buf2.sample(0)
    assert indice.tolist() == [2, 6]
    _, indice = buf2.sample(1)
    assert indice in [2, 6]
    with pytest.raises(IndexError):
        buf[bufsize * 2]
Example #4
def test_nstep_returns(size=10000):
    buf = ReplayBuffer(10)
    for i in range(12):
        buf.add(obs=0, act=0, rew=i + 1, done=i % 4 == 3)
    batch, indice = buf.sample(0)
    assert np.allclose(indice, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1])
    # rew:  [11, 12, 3, 4, 5, 6, 7, 8, 9, 10]
    # done: [ 0,  1, 0, 1, 0, 0, 0, 1, 0, 0]
    # test nstep = 1
    returns = to_numpy(BasePolicy.compute_nstep_return(
        batch, buf, indice, target_q_fn, gamma=.1, n_step=1).pop('returns'))
    assert np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12])
    r_ = compute_nstep_return_base(1, .1, buf, indice)
    assert np.allclose(returns, r_), (r_, returns)
    returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
        batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=1
    ).pop('returns'))
    assert np.allclose(returns_multidim, returns[:, np.newaxis])
    # test nstep = 2
    returns = to_numpy(BasePolicy.compute_nstep_return(
        batch, buf, indice, target_q_fn, gamma=.1, n_step=2).pop('returns'))
    assert np.allclose(returns, [
        3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12])
    r_ = compute_nstep_return_base(2, .1, buf, indice)
    assert np.allclose(returns, r_)
    returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
        batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=2
    ).pop('returns'))
    assert np.allclose(returns_multidim, returns[:, np.newaxis])
    # test nstep = 10
    returns = to_numpy(BasePolicy.compute_nstep_return(
        batch, buf, indice, target_q_fn, gamma=.1, n_step=10).pop('returns'))
    assert np.allclose(returns, [
        3.4, 4, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12])
    r_ = compute_nstep_return_base(10, .1, buf, indice)
    assert np.allclose(returns, r_)
    returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
        batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=10
    ).pop('returns'))
    assert np.allclose(returns_multidim, returns[:, np.newaxis])

    if __name__ == '__main__':
        buf = ReplayBuffer(size)
        for i in range(int(size * 1.5)):
            buf.add(obs=0, act=0, rew=i + 1, done=np.random.randint(3) == 0)
        batch, indice = buf.sample(256)

        def vanilla():
            return compute_nstep_return_base(3, .1, buf, indice)

        def optimized():
            return BasePolicy.compute_nstep_return(
                batch, buf, indice, target_q_fn, gamma=.1, n_step=3)

        cnt = 3000
        print('nstep vanilla', timeit(vanilla, setup=vanilla, number=cnt))
        print('nstep optim  ', timeit(optimized, setup=optimized, number=cnt))
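
compute_nstep_return_base, which the test compares against, is not shown in this listing. As a hedged illustration of the same idea, here is a plain single-trajectory n-step return recursion; the helper name, the toy arrays, and the q_next values are made up, and it ignores the circular-index handling a real ReplayBuffer needs.

import numpy as np

def nstep_return(rew, done, q_next, gamma, n_step):
    """rew[t], done[t] describe step t; q_next[t] is the bootstrap value of the
    state reached after step t. Accumulation stops at the end of an episode."""
    T = len(rew)
    returns = np.zeros(T)
    for t in range(T):
        ret, discount = 0.0, 1.0
        for k in range(n_step):
            idx = t + k
            if idx >= T:
                break
            ret += discount * rew[idx]
            discount *= gamma
            if done[idx]:
                discount = 0.0  # episode ended: no bootstrap term
                break
        last = min(t + n_step, T) - 1
        ret += discount * q_next[last]
        returns[t] = ret
    return returns

rew = np.array([1., 2., 3., 4.])
done = np.array([False, False, True, False])
q_next = np.array([10., 10., 10., 10.])
print(nstep_return(rew, done, q_next, gamma=0.9, n_step=2))  # 10.9, 4.7, 3.0, 13.0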
Example #5
def test_nstep_returns():
    buf = ReplayBuffer(10)
    for i in range(12):
        buf.add(obs=0, act=0, rew=i + 1, done=i % 4 == 3)
    batch, indice = buf.sample(0)
    assert np.allclose(indice, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1])
    # rew:  [11, 12, 3, 4, 5, 6, 7, 8, 9, 10]
    # done: [ 0,  1, 0, 1, 0, 0, 0, 1, 0, 0]
    # test nstep = 1
    returns = BasePolicy.compute_nstep_return(batch,
                                              buf,
                                              indice,
                                              target_q_fn,
                                              gamma=.1,
                                              n_step=1).pop('returns')
    assert np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12])
    # test nstep = 2
    returns = BasePolicy.compute_nstep_return(batch,
                                              buf,
                                              indice,
                                              target_q_fn,
                                              gamma=.1,
                                              n_step=2).pop('returns')
    assert np.allclose(returns,
                       [3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12])
    # test nstep = 10
    returns = BasePolicy.compute_nstep_return(batch,
                                              buf,
                                              indice,
                                              target_q_fn,
                                              gamma=.1,
                                              n_step=10).pop('returns')
    assert np.allclose(returns,
                       [3.4, 4, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12])
Example #6
def test_replaybuffer(size=10, bufsize=20):
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize)
    buf2 = ReplayBuffer(bufsize)
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        buf.add(obs, a, rew, done, obs_next, info)
        obs = obs_next
        assert len(buf) == min(bufsize, i + 1)
    data, indice = buf.sample(bufsize * 2)
    assert (indice < len(buf)).all()
    assert (data.obs < size).all()
    assert (0 <= data.done).all() and (data.done <= 1).all()
    assert len(buf) > len(buf2)
    buf2.update(buf)
    assert len(buf) == len(buf2)
    assert buf2[0].obs == buf[5].obs
    assert buf2[-1].obs == buf[4].obs
    b = ReplayBuffer(size=10)
    b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}})
    assert b.obs[0] == 1
    assert b.done[0] == 'str'
    assert np.all(b.obs[1:] == 0)
    assert np.all(b.done[1:] == np.array(None))
    assert b.info.a[0] == 3 and b.info.a.dtype == np.integer
    assert np.all(b.info.a[1:] == 0)
    assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact
    assert np.all(b.info.b.c[1:] == 0.0)
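
The two assertions about buf2[0] and buf2[-1] follow from the write-cursor arithmetic: 25 transitions were added to a buffer of size 20, so the cursor ends at 25 % 20 == 5, meaning index 5 holds the oldest surviving transition and index 4 the newest; update() copies in chronological order, hence buf2[0] matches buf[5] and buf2[-1] matches buf[4]. A quick check of that arithmetic:

adds, bufsize = 25, 20
cursor = adds % bufsize
print(cursor, (cursor - 1) % bufsize)  # 5 4 -> oldest at index 5, newest at index 4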
Example #7
def test_nstep_returns(size=10000):
    buf = ReplayBuffer(10)
    for i in range(12):
        buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % 4 == 3))
    batch, indices = buf.sample(0)
    assert np.allclose(indices, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1])
    # rew:  [11, 12, 3, 4, 5, 6, 7, 8, 9, 10]
    # done: [ 0,  1, 0, 1, 0, 0, 0, 1, 0, 0]
    # test nstep = 1
    returns = to_numpy(
        BasePolicy.compute_nstep_return(
            batch, buf, indices, target_q_fn, gamma=.1, n_step=1
        ).pop('returns').reshape(-1)
    )
    assert np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12])
    r_ = compute_nstep_return_base(1, .1, buf, indices)
    assert np.allclose(returns, r_), (r_, returns)
    returns_multidim = to_numpy(
        BasePolicy.compute_nstep_return(
            batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=1
        ).pop('returns')
    )
    assert np.allclose(returns_multidim, returns[:, np.newaxis])
    # test nstep = 2
    returns = to_numpy(
        BasePolicy.compute_nstep_return(
            batch, buf, indices, target_q_fn, gamma=.1, n_step=2
        ).pop('returns').reshape(-1)
    )
    assert np.allclose(returns, [3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12])
    r_ = compute_nstep_return_base(2, .1, buf, indices)
    assert np.allclose(returns, r_)
    returns_multidim = to_numpy(
        BasePolicy.compute_nstep_return(
            batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=2
        ).pop('returns')
    )
    assert np.allclose(returns_multidim, returns[:, np.newaxis])
    # test nstep = 10
    returns = to_numpy(
        BasePolicy.compute_nstep_return(
            batch, buf, indices, target_q_fn, gamma=.1, n_step=10
        ).pop('returns').reshape(-1)
    )
    assert np.allclose(returns, [3.4, 4, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12])
    r_ = compute_nstep_return_base(10, .1, buf, indices)
    assert np.allclose(returns, r_)
    returns_multidim = to_numpy(
        BasePolicy.compute_nstep_return(
            batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=10
        ).pop('returns')
    )
    assert np.allclose(returns_multidim, returns[:, np.newaxis])
Example #8
    def update(self, batch_size: int, buffer: ReplayBuffer, *args, **kwargs):
        """Update the policy network and replay buffer (if needed). It includes
        three function steps: process_fn, learn, and post_process_fn.

        :param int batch_size: 0 means it will extract all the data from the
            buffer, otherwise it will sample a batch with the given batch_size.
        :param ReplayBuffer buffer: the corresponding replay buffer.
        """
        batch, indice = buffer.sample(batch_size)
        batch = self.process_fn(batch, buffer, indice)
        result = self.learn(batch, *args, **kwargs)
        self.post_process_fn(batch, buffer, indice)
        return result
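
To make the three-step flow concrete, here is a self-contained toy that mirrors it with stand-in classes; ToyBuffer and ToyPolicy are illustrative assumptions, not tianshou classes.

import numpy as np

class ToyBuffer:
    """Stand-in buffer: sample() returns a batch dict plus the drawn indices."""

    def __init__(self, rews):
        self.rews = np.asarray(rews, dtype=float)

    def sample(self, batch_size):
        n = len(self.rews) if batch_size == 0 else batch_size
        indice = np.random.randint(0, len(self.rews), size=n)
        return {"rew": self.rews[indice]}, indice

class ToyPolicy:
    def process_fn(self, batch, buffer, indice):
        # e.g. attach n-step returns; here the batch passes through unchanged
        return batch

    def learn(self, batch):
        # pretend gradient step: report the mean sampled reward as a "loss"
        return {"loss": float(batch["rew"].mean())}

    def post_process_fn(self, batch, buffer, indice):
        # e.g. write updated priorities back into the buffer; no-op here
        pass

    def update(self, batch_size, buffer):
        batch, indice = buffer.sample(batch_size)
        batch = self.process_fn(batch, buffer, indice)
        result = self.learn(batch)
        self.post_process_fn(batch, buffer, indice)
        return result

print(ToyPolicy().update(4, ToyBuffer([1.0, 2.0, 3.0])))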
Example #9
def test_stack(size=5, bufsize=9, stack_num=4):
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize, stack_num=stack_num)
    buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
    obs = env.reset(1)
    for i in range(15):
        obs_next, rew, done, info = env.step(1)
        buf.add(obs, 1, rew, done, None, info)
        buf2.add(obs, 1, rew, done, None, info)
        obs = obs_next
        if done:
            obs = env.reset(1)
    indice = np.arange(len(buf))
    assert np.allclose(buf.get(indice, 'obs'), np.array([
        [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
        [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
        [3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]]))
    print(buf)
    _, indice = buf2.sample(0)
    assert indice == [2]
    _, indice = buf2.sample(1)
    assert indice.sum() == 2
Example #10
def test_stack(size=5, bufsize=9, stack_num=4):
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize, stack_num=stack_num)
    buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
    obs = env.reset(1)
    for i in range(16):
        obs_next, rew, done, info = env.step(1)
        buf.add(obs, 1, rew, done, None, info)
        buf2.add(obs, 1, rew, done, None, info)
        obs = obs_next
        if done:
            obs = env.reset(1)
    indice = np.arange(len(buf))
    assert np.allclose(
        buf.get(indice, 'obs'),
        np.expand_dims([[1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4], [1, 1, 1, 1],
                        [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4], [4, 4, 4, 4],
                        [1, 1, 1, 1]],
                       axis=-1))
    _, indice = buf2.sample(0)
    assert indice.tolist() == [2, 6]
    _, indice = buf2.sample(1)
    assert indice in [2, 6]
Example #11
def test_ReplayBuffer():
    """
    tianshou.data.ReplayBuffer
    buf.add()
    buf.get()
    buf.update()
    buf.sample()
    buf.reset()
    len(buf)
    :return:
    """
    buf1 = ReplayBuffer(size=15)
    for i in range(3):
        buf1.add(obs=i,
                 act=i,
                 rew=i,
                 done=i,
                 obs_next=i + 1,
                 info={},
                 weight=None)
    print(len(buf1))
    print(buf1.obs)
    buf2 = ReplayBuffer(size=10)
    for i in range(15):
        buf2.add(obs=i,
                 act=i,
                 rew=i,
                 done=i,
                 obs_next=i + 1,
                 info={},
                 weight=None)
    print(buf2.obs)
    buf1.update(buf2)
    print(buf1.obs)
    index = [1, 3, 5]
    # key is an obligatory argument
    print(buf2.get(index, key='obs'))
    print('--------------------')
    sample_data, indice = buf2.sample(batch_size=4)
    print(sample_data, indice)
    print(sample_data.obs == buf2[indice].obs)
    print('--------------------')
    # buf.reset() only resets the index, not the content.
    print(len(buf2))
    buf2.reset()
    print(len(buf2))
    print(buf2)
    print('--------------------')
Example #12
def test_replaybuffer(size=10, bufsize=20):
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize)
    buf2 = ReplayBuffer(bufsize)
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        buf.add(obs, a, rew, done, obs_next, info)
        obs = obs_next
        assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
    data, indice = buf.sample(bufsize * 2)
    assert (indice < len(buf)).all()
    assert (data.obs < size).all()
    assert (0 <= data.done).all() and (data.done <= 1).all()
    assert len(buf) > len(buf2)
    buf2.update(buf)
    assert len(buf) == len(buf2)
    assert buf2[0].obs == buf[5].obs
    assert buf2[-1].obs == buf[4].obs
Example #13
class SSACPolicy(DDPGPolicy):
    """Implementation of Simulator-based Soft Actor-Critic.
    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
    :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s,
        a))
    :param torch.optim.Optimizer critic1_optim: the optimizer for the first
        critic network.
    :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s,
        a))
    :param torch.optim.Optimizer critic2_optim: the optimizer for the second
        critic network.
    :param action_range: the action range (minimum, maximum).
    :type action_range: Tuple[float, float]
    :param float tau: param for soft update of the target network, defaults to
        0.005.
    :param float gamma: discount factor, in [0, 1], defaults to 0.99.
    :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
        regularization coefficient, defaults to 0.2.
        If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then
        alpha is automatically tuned.
    :param bool reward_normalization: normalize the reward to Normal(0, 1),
        defaults to False.
    :param bool ignore_done: ignore the done flag while training the policy,
        defaults to False.
    :param BaseNoise exploration_noise: add a noise to action for exploration,
        defaults to None. This is useful when solving hard-exploration problem.
    :param bool deterministic_eval: whether to use deterministic action (mean
        of Gaussian policy) instead of stochastic action sampled by the policy,
        defaults to True.
    .. seealso::
        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """
    def __init__(
        self,
        actor: torch.nn.Module,
        actor_optim: torch.optim.Optimizer,
        critic1: torch.nn.Module,
        critic1_optim: torch.optim.Optimizer,
        critic2: torch.nn.Module,
        critic2_optim: torch.optim.Optimizer,
        simulator: Optional[torch.nn.Module],
        args,
        action_range: Tuple[float, float],
        tau: float = 0.005,
        gamma: float = 0.99,
        alpha: Union[float, Tuple[float, torch.Tensor,
                                  torch.optim.Optimizer]] = 0.2,
        reward_normalization: bool = False,
        ignore_done: bool = False,
        estimation_step: int = 1,
        exploration_noise: Optional[BaseNoise] = None,
        deterministic_eval: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(None, None, None, None, action_range, tau, gamma,
                         exploration_noise, reward_normalization, ignore_done,
                         estimation_step, **kwargs)
        if simulator is not None:
            self.simulator = simulator
        self.args = args
        self.simulation_env = None
        self.loss_history = []
        self.gbm_model = None
        self.update_step = self.args.max_update_step
        self.simulator_buffer = ReplayBuffer(size=self.args.buffer_size)

        self.actor, self.actor_optim = actor, actor_optim
        self.critic1, self.critic1_old = critic1, deepcopy(critic1)
        self.critic1_old.eval()
        self.critic1_optim = critic1_optim
        self.critic2, self.critic2_old = critic2, deepcopy(critic2)
        self.critic2_old.eval()
        self.critic2_optim = critic2_optim
        self.start_simulation = False

        self._is_auto_alpha = False
        self._alpha: Union[float, torch.Tensor]
        if isinstance(alpha, tuple):
            self._is_auto_alpha = True
            self._target_entropy, self._log_alpha, self._alpha_optim = alpha
            assert alpha[1].shape == torch.Size([1]) and alpha[1].requires_grad
            self._alpha = self._log_alpha.detach().exp()
        else:
            self._alpha = alpha

        self._deterministic_eval = deterministic_eval
        self.__eps = np.finfo(np.float32).eps.item()

    def train(self, mode: bool = True) -> "SSACPolicy":
        self.training = mode
        self.actor.train(mode)
        self.critic1.train(mode)
        self.critic2.train(mode)
        return self

    def sync_weight(self) -> None:
        for o, n in zip(self.critic1_old.parameters(),
                        self.critic1.parameters()):
            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
        for o, n in zip(self.critic2_old.parameters(),
                        self.critic2.parameters()):
            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)

    def forward(  # type: ignore
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        obs = batch[input]
        logits, h = self.actor(obs, state=state, info=batch.info)
        assert isinstance(logits, tuple)
        dist = Independent(Normal(*logits), 1)
        if self._deterministic_eval and not self.training:
            x = logits[0]
        else:
            x = dist.rsample()
        y = torch.tanh(x)
        act = y * self._action_scale + self._action_bias
        y = self._action_scale * (1 - y.pow(2)) + self.__eps
        log_prob = dist.log_prob(x).unsqueeze(-1)
        log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
        if self._noise is not None and self.training and not self.updating:
            act += to_torch_as(self._noise(act.shape), act)
        act = act.clamp(self._range[0], self._range[1])
        return Batch(logits=logits,
                     act=act,
                     state=h,
                     dist=dist,
                     log_prob=log_prob)

    def _target_q(self, buffer: ReplayBuffer,
                  indice: np.ndarray) -> torch.Tensor:
        batch = buffer[indice]  # batch.obs: s_{t+n}
        with torch.no_grad():
            obs_next_result = self(batch, input='obs_next')
            a_ = obs_next_result.act
            batch.act = to_torch_as(batch.act, a_)
            target_q = torch.min(
                self.critic1_old(batch.obs_next, a_),
                self.critic2_old(batch.obs_next, a_),
            ) - self._alpha * obs_next_result.log_prob
        return target_q

    def learn_batch(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        weight = batch.pop("weight", 1.0)

        # critic 1
        current_q1 = self.critic1(batch.obs, batch.act).flatten()
        target_q = batch.returns.flatten()
        td1 = current_q1 - target_q
        critic1_loss = (td1.pow(2) * weight).mean()
        # critic1_loss = F.mse_loss(current_q1, target_q)
        self.critic1_optim.zero_grad()
        critic1_loss.backward()
        self.critic1_optim.step()

        # critic 2
        current_q2 = self.critic2(batch.obs, batch.act).flatten()
        td2 = current_q2 - target_q
        critic2_loss = (td2.pow(2) * weight).mean()
        # critic2_loss = F.mse_loss(current_q2, target_q)
        self.critic2_optim.zero_grad()
        critic2_loss.backward()
        self.critic2_optim.step()
        batch.weight = (td1 + td2) / 2.0  # prio-buffer

        # actor
        obs_result = self(batch)
        a = obs_result.act
        current_q1a = self.critic1(batch.obs, a).flatten()
        current_q2a = self.critic2(batch.obs, a).flatten()
        actor_loss = (self._alpha * obs_result.log_prob.flatten() -
                      torch.min(current_q1a, current_q2a)).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        if self._is_auto_alpha:
            log_prob = obs_result.log_prob.detach() + self._target_entropy
            alpha_loss = -(self._log_alpha * log_prob).mean()
            self._alpha_optim.zero_grad()
            alpha_loss.backward()
            self._alpha_optim.step()
            self._alpha = self._log_alpha.detach().exp()

        self.sync_weight()

        result = {
            "la": actor_loss.item(),
            "lc": (critic1_loss.item() + critic2_loss.item()) / 2.0,
        }
        if self._is_auto_alpha:
            result["lal"] = alpha_loss.item()
            result["a"] = self._alpha.item()  # type: ignore

        return result

    def get_loss_batch(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        with torch.no_grad():
            weight = batch.pop("weight", 1.0)

            # critic 1
            current_q1 = self.critic1(batch.obs, batch.act).flatten()
            target_q = batch.returns.flatten()
            td1 = current_q1 - target_q
            critic1_loss = (td1.pow(2) * weight).mean()
            # critic1_loss = F.mse_loss(current_q1, target_q)

            # critic 2
            current_q2 = self.critic2(batch.obs, batch.act).flatten()
            td2 = current_q2 - target_q
            critic2_loss = (td2.pow(2) * weight).mean()
            # critic2_loss = F.mse_loss(current_q2, target_q)
            # batch.weight = (td1 + td2) / 2.0  # prio-buffer

            # actor
            obs_result = self(batch)
            a = obs_result.act
            current_q1a = self.critic1(batch.obs, a).flatten()
            current_q2a = self.critic2(batch.obs, a).flatten()
            actor_loss = (self._alpha * obs_result.log_prob.flatten() -
                          torch.min(current_q1a, current_q2a)).mean()

            if self._is_auto_alpha:
                log_prob = obs_result.log_prob.detach() + self._target_entropy
                alpha_loss = -(self._log_alpha * log_prob).mean()
                self._alpha_optim.zero_grad()
                alpha_loss.backward()
                self._alpha_optim.step()
                self._alpha = self._log_alpha.detach().exp()

            # self.sync_weight()

            result = {
                "la": actor_loss.item(),
                "lc": (critic1_loss.item() + critic2_loss.item()) / 2.0,
            }
        if self._is_auto_alpha:
            result["lal"] = alpha_loss.item()
            result["a"] = self._alpha.item()  # type: ignore

        return result

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        if self.update_step > 0:
            self.update_step -= 1
            batch.obs += self.args.noise_obs * np.random.randn(
                *np.shape(batch.obs))
            batch.rew += self.args.noise_rew * np.random.randn(
                *np.shape(batch.rew))
            simulator_loss = self.learn_simulator(batch)
            result = self.learn_batch(batch)
            result["lt"] = simulator_loss[0]
            result["lr"] = simulator_loss[1]
            # result["m"] = self.simulator.m
            # result["l"] = self.simulator.l
            # result["g"] = self.simulator.g
            # result["dt"] = self.simulator.dt
            self.loss_history.append([
                simulator_loss[0], simulator_loss[1], result["la"],
                result["lc"], 0, 0
            ])
        else:
            if not self.start_simulation:
                kwargs['writer'].add_scalar('simulator/start_step',
                                            kwargs['env_step'],
                                            global_step=kwargs['env_step'])
                self.start_simulation = True
            result = self.get_loss_batch(batch)
            if kwargs['i'] == 0 or \
                    self.simulator_buffer._size < self.args.batch_size:
                self.simulate_environment()
            simulation_batch, indice = self.simulator_buffer.sample(
                self.args.batch_size)
            simulation_batch = self.process_fn(simulation_batch,
                                               self.simulator_buffer, indice)
            simulator_result = self.learn_batch(simulation_batch)
            self.post_process_fn(simulation_batch, self.simulator_buffer,
                                 indice)
            result["la2"] = simulator_result["la"]
            result["lc2"] = simulator_result["lc"]
            self.loss_history.append([
                0, 0, result["la"], result["lc"], result["la2"], result["lc2"]
            ])
        return result

    def simulate_environment(self):
        self.simulation_env = SimulationEnv(self.args, self.simulator)
        obs, act, rew, done, info = [], [], [], [], []
        obs.append(self.simulation_env.reset())
        for i in range(self.args.n_simulator_step):
            with torch.no_grad():
                act.append(self(Batch(obs=obs[-1], info={})).act.cpu().numpy())
            result = self.simulation_env.step(act[-1])
            obs.append(result[0])
            rew.append(result[1])
            done.append(result[2])
            info.append(result[3])
        obs_next = np.array(obs[1:])
        obs = np.array(obs[:-1])
        act = np.array(act)
        rew = np.array(rew)
        done = np.array(done)
        for j in range(obs.shape[1]):
            for i in range(self.args.n_simulator_step):
                self.simulator_buffer.add(obs[i, j], act[i, j], rew[i, j],
                                          done[i, j], obs_next[i, j])
        return None

    def learn_simulator(self, batch: Batch):
        target_obs, target_rew = torch.tensor(
            batch.obs_next).float(), torch.tensor(batch.rew).float()
        target_obs = target_obs.to(self.args.device)
        target_rew = target_rew.to(self.args.device)
        targets = [target_obs, target_rew]
        losses = self.simulator(batch.obs,
                                batch.act,
                                white_box=self.args.white_box,
                                train=True,
                                targets=targets,
                                step=self.update_step)
        return losses[0], losses[1]
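
The constructor above requires that, when alpha is a tuple, alpha[1] has shape [1] and requires_grad. A hedged sketch of how such a tuple for automatic entropy tuning is typically built; action_dim=6 and the learning rate are assumptions, not values from this code.

import torch

action_dim = 6                                    # assumed action dimensionality
target_entropy = -float(action_dim)               # common SAC heuristic: -dim(A)
log_alpha = torch.zeros(1, requires_grad=True)    # shape [1], requires_grad=True
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
alpha = (target_entropy, log_alpha, alpha_optim)  # pass alpha=alpha to the policy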
Example #14
class SDDPGPolicy(BasePolicy):
    """Implementation of Simulator-based Deep Deterministic Policy Gradient.
    We combine DDPG with a model-based simulator.

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
    :param torch.nn.Module critic: the critic network. (s, a -> Q(s, a))
    :param torch.optim.Optimizer critic_optim: the optimizer for critic
        network.
    :param torch.nn.Module simulator: the simulator network for the environment.
    :param argparse.Namespace args: the arguments.
    :param action_range: the action range (minimum, maximum).
    :type action_range: Tuple[float, float]
    :param float tau: param for soft update of the target network, defaults to
        0.005.
    :param float gamma: discount factor, in [0, 1], defaults to 0.99.
    :param BaseNoise exploration_noise: the exploration noise,
        add to the action, defaults to ``GaussianNoise(sigma=0.1)``.
    :param bool reward_normalization: normalize the reward to Normal(0, 1),
        defaults to False.
    :param bool ignore_done: ignore the done flag while training the policy,
        defaults to False.
    :param int estimation_step: the number of steps to look ahead; must be
        greater than 0, defaults to 1.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """
    def __init__(
        self,
        actor: Optional[torch.nn.Module],
        actor_optim: Optional[torch.optim.Optimizer],
        critic: Optional[torch.nn.Module],
        critic_optim: Optional[torch.optim.Optimizer],
        simulator: Optional[torch.nn.Module],
        args,
        action_range: Tuple[float, float],
        tau: float = 0.005,
        gamma: float = 0.99,
        exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
        reward_normalization: bool = False,
        ignore_done: bool = False,
        estimation_step: int = 1,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        if actor is not None and actor_optim is not None:
            self.actor: torch.nn.Module = actor
            self.actor_old = deepcopy(actor)
            self.actor_old.eval()
            self.actor_optim: torch.optim.Optimizer = actor_optim
        if critic is not None and critic_optim is not None:
            self.critic: torch.nn.Module = critic
            self.critic_old = deepcopy(critic)
            self.critic_old.eval()
            self.critic_optim: torch.optim.Optimizer = critic_optim
        if simulator is not None:
            self.simulator = simulator
        self.args = args
        self.simulation_env = None
        self.simulator_loss_threshold = self.args.simulator_loss_threshold
        self.base_env = gym.make(args.task)
        assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]"
        self._tau = tau
        assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]"
        self._gamma = gamma
        self._noise = exploration_noise
        self._range = action_range
        self._action_bias = (action_range[0] + action_range[1]) / 2.0
        self._action_scale = (action_range[1] - action_range[0]) / 2.0
        # it makes only a little difference to use GaussianNoise instead of OUNoise
        # self.noise = OUNoise()
        self._rm_done = ignore_done
        self._rew_norm = reward_normalization
        assert estimation_step > 0, "estimation_step should be greater than 0"
        self._n_step = estimation_step
        self.loss_history = []
        self.gbm_model = None
        self.update_step = self.args.max_update_step
        self.simulator_buffer = ReplayBuffer(size=self.args.buffer_size)

    def set_exp_noise(self, noise: Optional[BaseNoise]) -> None:
        """Set the exploration noise."""
        self._noise = noise

    def train(self, mode: bool = True) -> "SDDPGPolicy":
        """Set the module in training mode, except for the target network."""
        self.training = mode
        self.actor.train(mode)
        self.critic.train(mode)
        self.simulator.train(mode)
        return self

    def sync_weight(self) -> None:
        """Soft-update the weight for the target network."""
        for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
        for o, n in zip(self.critic_old.parameters(),
                        self.critic.parameters()):
            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)

    def _target_q(self, buffer: ReplayBuffer,
                  indice: np.ndarray) -> torch.Tensor:
        batch = buffer[indice]  # batch.obs_next: s_{t+n}
        with torch.no_grad():
            target_q = self.critic_old(
                batch.obs_next,
                self(batch, model='actor_old', input='obs_next').act)
        return target_q

    def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                   indice: np.ndarray) -> Batch:
        if self._rm_done:
            batch.done = batch.done * 0.0
        batch = self.compute_nstep_return(batch, buffer, indice,
                                          self._target_q, self._gamma,
                                          self._n_step, self._rew_norm)
        return batch

    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        model: str = "actor",
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        """Compute action over the given batch data.

        :return: A :class:`~tianshou.data.Batch` which has 2 keys:

            * ``act`` the action.
            * ``state`` the hidden state.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        model = getattr(self, model)
        obs = batch[input]
        actions, h = model(obs, state=state, info=batch.info)
        actions += self._action_bias
        if self._noise and not self.updating:
            actions += to_torch_as(self._noise(actions.shape), actions)
        actions = actions.clamp(self._range[0], self._range[1])
        return Batch(act=actions, state=h)

    def learn_batch(self, batch: Batch) -> Dict[str, float]:
        weight = batch.pop("weight", 1.0)
        current_q = self.critic(batch.obs, batch.act).flatten()
        target_q = batch.returns.flatten()
        td = current_q - target_q
        critic_loss = (td.pow(2) * weight).mean()
        batch.weight = td  # prio-buffer
        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()
        action = self(batch).act
        actor_loss = -self.critic(batch.obs, action).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()
        self.sync_weight()
        return {
            "la": actor_loss.item(),
            "lc": critic_loss.item(),
        }

    def get_loss_batch(self, batch: Batch) -> Dict[str, float]:
        weight = batch.pop("weight", 1.0)
        with torch.no_grad():
            current_q = self.critic(batch.obs, batch.act).flatten()
            target_q = batch.returns.flatten()
            td = current_q - target_q
            critic_loss = (td.pow(2) * weight).mean()
            action = self(batch).act
            actor_loss = -self.critic(batch.obs, action).mean()
        return {
            "la": actor_loss.item(),
            "lc": critic_loss.item(),
        }

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        if self.update_step > 0:
            self.update_step -= 1
            simulator_loss = self.learn_simulator(batch)
            result = self.learn_batch(batch)
            result["lt"] = simulator_loss[0]
            result["lr"] = simulator_loss[1]
            # result["m"] = self.simulator.m
            # result["l"] = self.simulator.l
            # result["g"] = self.simulator.g
            # result["dt"] = self.simulator.dt
            self.loss_history.append([
                simulator_loss[0], simulator_loss[1], result["la"],
                result["lc"], 0, 0
            ])
        else:
            result = self.get_loss_batch(batch)
            if kwargs['i'] == 0 or \
                    self.simulator_buffer._size < self.args.batch_size:
                self.simulate_environment()
            simulation_batch, indice = self.simulator_buffer.sample(
                self.args.batch_size)
            simulation_batch = self.process_fn(simulation_batch,
                                               self.simulator_buffer, indice)
            simulator_result = self.learn_batch(simulation_batch)
            self.post_process_fn(simulation_batch, self.simulator_buffer,
                                 indice)
            result["la2"] = simulator_result["la"]
            result["lc2"] = simulator_result["lc"]
            self.loss_history.append([
                0, 0, result["la"], result["lc"], result["la2"], result["lc2"]
            ])
        return result

    def simulate_environment(self):
        self.simulation_env = SimulationEnv(self.args, self.simulator)
        obs, act, rew, done, info = [], [], [], [], []
        obs.append(self.simulation_env.reset())
        for i in range(self.args.n_simulator_step):
            with torch.no_grad():
                act.append(self(Batch(obs=obs[-1], info={})).act.cpu().numpy())
            result = self.simulation_env.step(act[-1])
            obs.append(result[0])
            rew.append(result[1])
            done.append(result[2])
            info.append(result[3])
        obs_next = np.array(obs[1:])
        obs = np.array(obs[:-1])
        act = np.array(act)
        rew = np.array(rew)
        done = np.array(done)
        # obs = obs.reshape(-1, obs.shape[-1])
        # act = act.reshape(-1, act.shape[-1])
        # rew = np.array(rew).reshape(-1)
        # done = np.array(done).reshape(-1)
        # obs_next = obs_next.reshape(-1, obs_next.shape[-1])
        # rew = rew.reshape(obs.shape[0], obs.shape[1])
        for j in range(obs.shape[1]):
            for i in range(self.args.n_simulator_step):
                self.simulator_buffer.add(obs[i, j], act[i, j], rew[i, j],
                                          done[i, j], obs_next[i, j])
        return None

    def learn_simulator(self, batch: Batch):
        target_obs, target_rew = torch.tensor(
            batch.obs_next).float(), torch.tensor(batch.rew).float()
        target_obs = target_obs.to(self.args.device)
        target_rew = target_rew.to(self.args.device)
        targets = [target_obs, target_rew]
        losses = self.simulator(batch.obs,
                                batch.act,
                                white_box=self.args.white_box,
                                train=True,
                                targets=targets,
                                step=self.update_step)
        return losses[0], losses[1]
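
sync_weight() above is the usual Polyak (soft) target update, target <- (1 - tau) * target + tau * online, applied parameter-wise. A minimal standalone sketch of the same rule; the Linear module shapes are arbitrary assumptions.

import torch

def soft_update(target: torch.nn.Module, online: torch.nn.Module, tau: float) -> None:
    """Blend online parameters into the target network with Polyak averaging."""
    with torch.no_grad():
        for o, n in zip(target.parameters(), online.parameters()):
            o.data.mul_(1.0 - tau).add_(n.data, alpha=tau)

online, target = torch.nn.Linear(4, 2), torch.nn.Linear(4, 2)
target.load_state_dict(online.state_dict())
soft_update(target, online, tau=0.005)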
Example #15
class Collector(object):
    """docstring for Collector"""
    def __init__(self, policy, env, buffer=None, stat_size=100):
        super().__init__()
        self.env = env
        self.env_num = 1
        self.collect_step = 0
        self.collect_episode = 0
        self.collect_time = 0
        if buffer is None:
            self.buffer = ReplayBuffer(100)
        else:
            self.buffer = buffer
        self.policy = policy
        self.process_fn = policy.process_fn
        self._multi_env = isinstance(env, BaseVectorEnv)
        self._multi_buf = False  # True if buf is a list
        # need multiple cache buffers only if storing in one buffer
        self._cached_buf = []
        if self._multi_env:
            self.env_num = len(env)
            if isinstance(self.buffer, list):
                assert len(self.buffer) == self.env_num, \
                    'The number of data buffers does not match the number ' \
                    'of input envs.'
                self._multi_buf = True
            elif isinstance(self.buffer, ReplayBuffer):
                self._cached_buf = [
                    ListReplayBuffer() for _ in range(self.env_num)
                ]
            else:
                raise TypeError('The buffer in data collector is invalid!')
        self.reset_env()
        self.reset_buffer()
        # state over batch is either a list, an np.ndarray, or a torch.Tensor
        self.state = None
        self.step_speed = MovAvg(stat_size)
        self.episode_speed = MovAvg(stat_size)

    def reset_buffer(self):
        if self._multi_buf:
            for b in self.buffer:
                b.reset()
        else:
            self.buffer.reset()

    def get_env_num(self):
        return self.env_num

    def reset_env(self):
        self._obs = self.env.reset()
        self._act = self._rew = self._done = self._info = None
        if self._multi_env:
            self.reward = np.zeros(self.env_num)
            self.length = np.zeros(self.env_num)
        else:
            self.reward, self.length = 0, 0
        for b in self._cached_buf:
            b.reset()

    def seed(self, seed=None):
        if hasattr(self.env, 'seed'):
            return self.env.seed(seed)

    def render(self, **kwargs):
        if hasattr(self.env, 'render'):
            return self.env.render(**kwargs)

    def close(self):
        if hasattr(self.env, 'close'):
            self.env.close()

    def _make_batch(self, data):
        if isinstance(data, np.ndarray):
            return data[None]
        else:
            return np.array([data])

    def collect(self, n_step=0, n_episode=0, render=0):
        warning_count = 0
        if not self._multi_env:
            n_episode = np.sum(n_episode)
        start_time = time.time()
        assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
            "One and only one collection number specification permitted!"
        cur_step = 0
        cur_episode = np.zeros(self.env_num) if self._multi_env else 0
        reward_sum = 0
        length_sum = 0
        while True:
            if warning_count >= 100000:
                warnings.warn(
                    'There are already many steps in an episode. '
                    'You should add a time limitation to your environment!',
                    Warning)
            if self._multi_env:
                batch_data = Batch(obs=self._obs,
                                   act=self._act,
                                   rew=self._rew,
                                   done=self._done,
                                   obs_next=None,
                                   info=self._info)
            else:
                batch_data = Batch(obs=self._make_batch(self._obs),
                                   act=self._make_batch(self._act),
                                   rew=self._make_batch(self._rew),
                                   done=self._make_batch(self._done),
                                   obs_next=None,
                                   info=self._make_batch(self._info))
            result = self.policy(batch_data, self.state)
            self.state = result.state if hasattr(result, 'state') else None
            if isinstance(result.act, torch.Tensor):
                self._act = result.act.detach().cpu().numpy()
            elif not isinstance(self._act, np.ndarray):
                self._act = np.array(result.act)
            else:
                self._act = result.act
            obs_next, self._rew, self._done, self._info = self.env.step(
                self._act if self._multi_env else self._act[0])
            if render > 0:
                self.env.render()
                time.sleep(render)
            self.length += 1
            self.reward += self._rew
            if self._multi_env:
                for i in range(self.env_num):
                    data = {
                        'obs': self._obs[i],
                        'act': self._act[i],
                        'rew': self._rew[i],
                        'done': self._done[i],
                        'obs_next': obs_next[i],
                        'info': self._info[i]
                    }
                    if self._cached_buf:
                        warning_count += 1
                        self._cached_buf[i].add(**data)
                    elif self._multi_buf:
                        warning_count += 1
                        self.buffer[i].add(**data)
                        cur_step += 1
                    else:
                        warning_count += 1
                        self.buffer.add(**data)
                        cur_step += 1
                    if self._done[i]:
                        if n_step != 0 or np.isscalar(n_episode) or \
                                cur_episode[i] < n_episode[i]:
                            cur_episode[i] += 1
                            reward_sum += self.reward[i]
                            length_sum += self.length[i]
                            if self._cached_buf:
                                cur_step += len(self._cached_buf[i])
                                self.buffer.update(self._cached_buf[i])
                        self.reward[i], self.length[i] = 0, 0
                        if self._cached_buf:
                            self._cached_buf[i].reset()
                        if isinstance(self.state, list):
                            self.state[i] = None
                        elif self.state is not None:
                            if isinstance(self.state[i], dict):
                                self.state[i] = {}
                            else:
                                self.state[i] = self.state[i] * 0
                            if isinstance(self.state, torch.Tensor):
                                # remove ref count in pytorch (?)
                                self.state = self.state.detach()
                if sum(self._done):
                    obs_next = self.env.reset(np.where(self._done)[0])
                if n_episode != 0:
                    if isinstance(n_episode, list) and \
                            (cur_episode >= np.array(n_episode)).all() or \
                            np.isscalar(n_episode) and \
                            cur_episode.sum() >= n_episode:
                        break
            else:
                self.buffer.add(self._obs, self._act[0], self._rew, self._done,
                                obs_next, self._info)
                cur_step += 1
                if self._done:
                    cur_episode += 1
                    reward_sum += self.reward
                    length_sum += self.length
                    self.reward, self.length = 0, 0
                    self.state = None
                    obs_next = self.env.reset()
                if n_episode != 0 and cur_episode >= n_episode:
                    break
            if n_step != 0 and cur_step >= n_step:
                break
            self._obs = obs_next
        self._obs = obs_next
        if self._multi_env:
            cur_episode = sum(cur_episode)
        duration = time.time() - start_time
        self.step_speed.add(cur_step / duration)
        self.episode_speed.add(cur_episode / duration)
        self.collect_step += cur_step
        self.collect_episode += cur_episode
        self.collect_time += duration
        if isinstance(n_episode, list):
            n_episode = np.sum(n_episode)
        else:
            n_episode = max(cur_episode, 1)
        return {
            'n/ep': cur_episode,
            'n/st': cur_step,
            'v/st': self.step_speed.get(),
            'v/ep': self.episode_speed.get(),
            'rew': reward_sum / n_episode,
            'len': length_sum / n_episode,
        }

    def sample(self, batch_size):
        if self._multi_buf:
            if batch_size > 0:
                lens = [len(b) for b in self.buffer]
                total = sum(lens)
                # draw buffer indices in proportion to each buffer's length
                batch_index = np.random.choice(len(self.buffer),
                                               batch_size,
                                               p=np.array(lens) / total)
            else:
                batch_index = np.array([])
            batch_data = Batch()
            for i, b in enumerate(self.buffer):
                cur_batch = (batch_index == i).sum()
                if batch_size and cur_batch or batch_size <= 0:
                    batch, indice = b.sample(cur_batch)
                    batch = self.process_fn(batch, b, indice)
                    batch_data.append(batch)
        else:
            batch_data, indice = self.buffer.sample(batch_size)
            batch_data = self.process_fn(batch_data, self.buffer, indice)
        return batch_data
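
sample() above draws from multiple buffers in proportion to how much data each one holds (assuming the draw is over buffer indices, as the later (batch_index == i).sum() comparison suggests). A small illustration of that proportional pick with made-up lengths:

import numpy as np

lens = [100, 300, 100]                # assumed per-buffer sizes
total = sum(lens)
idx = np.random.choice(len(lens), 10000, p=np.array(lens) / total)
print([int((idx == i).sum()) for i in range(len(lens))])  # roughly [2000, 6000, 2000]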
Example #16
class Collector(object):
    """The :class:`~tianshou.data.Collector` enables the policy to interact
    with different types of environments conveniently.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param env: an environment or an instance of the
        :class:`~tianshou.env.BaseVectorEnv` class.
    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer`
        class, or a list of :class:`~tianshou.data.ReplayBuffer`. If set to
        ``None``, it will automatically assign a small-size
        :class:`~tianshou.data.ReplayBuffer`.
    :param int stat_size: for the moving average of recording speed, defaults
        to 100.

    Example:
    ::

        policy = PGPolicy(...)  # or other policies if you wish
        env = gym.make('CartPole-v0')
        replay_buffer = ReplayBuffer(size=10000)
        # here we set up a collector with a single environment
        collector = Collector(policy, env, buffer=replay_buffer)

        # the collector supports vectorized environments as well
        envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
        buffers = [ReplayBuffer(size=5000) for _ in range(3)]
        # you can also pass a list of replay buffers to collector, for multi-env
        # collector = Collector(policy, envs, buffer=buffers)
        collector = Collector(policy, envs, buffer=replay_buffer)

        # collect at least 3 episodes
        collector.collect(n_episode=3)
        # collect 1 episode for the first env, 3 for the third env
        collector.collect(n_episode=[1, 0, 3])
        # collect at least 2 steps
        collector.collect(n_step=2)
        # collect episodes with visual rendering (the render argument is the
        #   sleep time between rendering consecutive frames)
        collector.collect(n_episode=1, render=0.03)

        # sample data with a given number of batch-size:
        batch_data = collector.sample(batch_size=64)
        # policy.learn(batch_data)  # btw, vanilla policy gradient only
        #   supports on-policy training, so here we pick all data in the buffer
        batch_data = collector.sample(batch_size=0)
        policy.learn(batch_data)
        # on-policy algorithms use the collected data only once, so here we
        #   clear the buffer
        collector.reset_buffer()

    For the scenario of collecting data from multiple environments into a
    single buffer, the cache buffers are turned on automatically. The collector
    may then return more data than the given limit.

    .. note::

        Please make sure the given environment has a time limitation.
    """

    def __init__(self, policy, env, buffer=None, stat_size=100, **kwargs):
        super().__init__()
        self.env = env
        self.env_num = 1
        self.collect_step = 0
        self.collect_episode = 0
        self.collect_time = 0
        if buffer is None:
            self.buffer = ReplayBuffer(100)
        else:
            self.buffer = buffer
        self.policy = policy
        self.process_fn = policy.process_fn
        self._multi_env = isinstance(env, BaseVectorEnv)
        self._multi_buf = False  # True if buf is a list
        # need multiple cache buffers only if storing in one buffer
        self._cached_buf = []
        if self._multi_env:
            self.env_num = len(env)
            if isinstance(self.buffer, list):
                assert len(self.buffer) == self.env_num, \
                    'The number of data buffers does not match the number ' \
                    'of input envs.'
                self._multi_buf = True
            elif isinstance(self.buffer, ReplayBuffer):
                self._cached_buf = [
                    ListReplayBuffer() for _ in range(self.env_num)]
            else:
                raise TypeError('The buffer in data collector is invalid!')
        self.reset_env()
        self.reset_buffer()
        # state over batch is either a list, an np.ndarray, or a torch.Tensor
        self.state = None
        self.step_speed = MovAvg(stat_size)
        self.episode_speed = MovAvg(stat_size)

    def reset_buffer(self):
        """Reset the main data buffer."""
        if self._multi_buf:
            for b in self.buffer:
                b.reset()
        else:
            self.buffer.reset()

    def get_env_num(self):
        """Return the number of environments the collector has."""
        return self.env_num

    def reset_env(self):
        """Reset all of the environment(s)' states and reset all of the cache
        buffers (if needed).
        """
        self._obs = self.env.reset()
        self._act = self._rew = self._done = self._info = None
        if self._multi_env:
            self.reward = np.zeros(self.env_num)
            self.length = np.zeros(self.env_num)
        else:
            self.reward, self.length = 0, 0
        for b in self._cached_buf:
            b.reset()

    def seed(self, seed=None):
        """Reset all the seed(s) of the given environment(s)."""
        if hasattr(self.env, 'seed'):
            return self.env.seed(seed)

    def render(self, **kwargs):
        """Render all the environment(s)."""
        if hasattr(self.env, 'render'):
            return self.env.render(**kwargs)

    def close(self):
        """Close the environment(s)."""
        if hasattr(self.env, 'close'):
            self.env.close()

    def _make_batch(self, data):
        """Return [data]."""
        if isinstance(data, np.ndarray):
            return data[None]
        else:
            return np.array([data])

    def _reset_state(self, id):
        """Reset self.state[id]."""
        if self.state is None:
            return
        if isinstance(self.state, list):
            self.state[id] = None
        elif isinstance(self.state, dict):
            for k in self.state:
                if isinstance(self.state[k], list):
                    self.state[k][id] = None
                elif isinstance(self.state[k], torch.Tensor) or \
                        isinstance(self.state[k], np.ndarray):
                    self.state[k][id] = 0
        elif isinstance(self.state, torch.Tensor) or \
                isinstance(self.state, np.ndarray):
            self.state[id] = 0

    def collect(self, n_step=0, n_episode=0, render=None):
        """Collect a specified number of step or episode.

        :param int n_step: how many steps you want to collect.
        :param n_episode: how many episodes you want to collect (in each
            environment).
        :type n_episode: int or list
        :param float render: the sleep time between rendering consecutive
            frames, defaults to ``None`` (no rendering).

        .. note::

            One and only one collection number specification is permitted,
            either ``n_step`` or ``n_episode``.

        :return: A dict including the following keys

            * ``n/ep`` the collected number of episodes.
            * ``n/st`` the collected number of steps.
            * ``v/st`` the speed of steps per second.
            * ``v/ep`` the speed of episodes per second.
            * ``rew`` the mean reward over collected episodes.
            * ``len`` the mean length over collected episodes.
        """
        warning_count = 0
        if not self._multi_env:
            n_episode = np.sum(n_episode)
        start_time = time.time()
        assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
            "One and only one collection number specification is permitted!"
        cur_step = 0
        cur_episode = np.zeros(self.env_num) if self._multi_env else 0
        reward_sum = 0
        length_sum = 0
        while True:
            if warning_count >= 100000:
                warnings.warn(
                    'There are already many steps in an episode. '
                    'You should add a time limit to your environment!',
                    Warning)
            if self._multi_env:
                batch_data = Batch(
                    obs=self._obs, act=self._act, rew=self._rew,
                    done=self._done, obs_next=None, info=self._info)
            else:
                batch_data = Batch(
                    obs=self._make_batch(self._obs),
                    act=self._make_batch(self._act),
                    rew=self._make_batch(self._rew),
                    done=self._make_batch(self._done),
                    obs_next=None,
                    info=self._make_batch(self._info))
            with torch.no_grad():
                result = self.policy(batch_data, self.state)
            self.state = result.state if hasattr(result, 'state') else None
            if isinstance(result.act, torch.Tensor):
                self._act = result.act.detach().cpu().numpy()
            elif not isinstance(result.act, np.ndarray):
                self._act = np.array(result.act)
            else:
                self._act = result.act
            obs_next, self._rew, self._done, self._info = self.env.step(
                self._act if self._multi_env else self._act[0])
            if render is not None:
                self.env.render()
                if render > 0:
                    time.sleep(render)
            self.length += 1
            self.reward += self._rew
            if self._multi_env:
                for i in range(self.env_num):
                    data = {
                        'obs': self._obs[i], 'act': self._act[i],
                        'rew': self._rew[i], 'done': self._done[i],
                        'obs_next': obs_next[i], 'info': self._info[i]}
                    if self._cached_buf:
                        warning_count += 1
                        self._cached_buf[i].add(**data)
                    elif self._multi_buf:
                        warning_count += 1
                        self.buffer[i].add(**data)
                        cur_step += 1
                    else:
                        warning_count += 1
                        self.buffer.add(**data)
                        cur_step += 1
                    if self._done[i]:
                        if n_step != 0 or np.isscalar(n_episode) or \
                                cur_episode[i] < n_episode[i]:
                            cur_episode[i] += 1
                            reward_sum += self.reward[i]
                            length_sum += self.length[i]
                            if self._cached_buf:
                                cur_step += len(self._cached_buf[i])
                                self.buffer.update(self._cached_buf[i])
                        self.reward[i], self.length[i] = 0, 0
                        if self._cached_buf:
                            self._cached_buf[i].reset()
                        self._reset_state(i)
                if sum(self._done):
                    obs_next = self.env.reset(np.where(self._done)[0])
                if n_episode != 0:
                    if isinstance(n_episode, list) and \
                            (cur_episode >= np.array(n_episode)).all() or \
                            np.isscalar(n_episode) and \
                            cur_episode.sum() >= n_episode:
                        break
            else:
                self.buffer.add(
                    self._obs, self._act[0], self._rew,
                    self._done, obs_next, self._info)
                cur_step += 1
                if self._done:
                    cur_episode += 1
                    reward_sum += self.reward
                    length_sum += self.length
                    self.reward, self.length = 0, 0
                    self.state = None
                    obs_next = self.env.reset()
                if n_episode != 0 and cur_episode >= n_episode:
                    break
            if n_step != 0 and cur_step >= n_step:
                break
            self._obs = obs_next
        self._obs = obs_next
        if self._multi_env:
            cur_episode = sum(cur_episode)
        duration = max(time.time() - start_time, 1e-9)
        self.step_speed.add(cur_step / duration)
        self.episode_speed.add(cur_episode / duration)
        self.collect_step += cur_step
        self.collect_episode += cur_episode
        self.collect_time += duration
        if isinstance(n_episode, list):
            n_episode = np.sum(n_episode)
        else:
            n_episode = max(cur_episode, 1)
        return {
            'n/ep': cur_episode,
            'n/st': cur_step,
            'v/st': self.step_speed.get(),
            'v/ep': self.episode_speed.get(),
            'rew': reward_sum / n_episode,
            'len': length_sum / n_episode,
        }

    def sample(self, batch_size):
        """Sample a data batch from the internal replay buffer. It will call
        :meth:`~tianshou.policy.BasePolicy.process_fn` before returning
        the final batch data.

        :param int batch_size: ``0`` means it will extract all the data from
            the buffer, otherwise it will extract the data with the given
            batch_size.
        """
        if self._multi_buf:
            if batch_size > 0:
                lens = [len(b) for b in self.buffer]
                total = sum(lens)
                batch_index = np.random.choice(
                    total, batch_size, p=np.array(lens) / total)
            else:
                batch_index = np.array([])
            batch_data = Batch()
            for i, b in enumerate(self.buffer):
                cur_batch = (batch_index == i).sum()
                if (batch_size and cur_batch) or batch_size <= 0:
                    batch, indice = b.sample(cur_batch)
                    batch = self.process_fn(batch, b, indice)
                    batch_data.append(batch)
        else:
            batch_data, indice = self.buffer.sample(batch_size)
            batch_data = self.process_fn(batch_data, self.buffer, indice)
        return batch_data
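

# --- hedged usage sketch (not part of the collector class above) ------------
# A minimal example of how this collector is meant to be driven, following its
# own docstring: collect transitions, sample from the internal buffer, and
# feed the processed batch to the policy.  The `run_once` name and the
# CartPole task are illustrative assumptions; `policy` is expected to be a
# tianshou policy, and `Collector` refers to the class defined above.
import gym
from tianshou.data import ReplayBuffer


def run_once(policy, n_step=100, batch_size=64):
    env = gym.make('CartPole-v0')                 # placeholder task
    collector = Collector(policy, env, buffer=ReplayBuffer(1000))
    stats = collector.collect(n_step=n_step)      # {'n/ep', 'n/st', 'rew', ...}
    batch = collector.sample(batch_size)          # already run through process_fn
    policy.learn(batch)                           # single update, as in docstring
    collector.close()
    return stats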
Example #17
0
File: MAPPO.py Project: Luckych454/MAFRL
def test_Fedppo(args=get_args()):
    torch.set_num_threads(1)  # use a single thread (helps on weak CPUs)
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    # train_envs = DummyVectorEnv(
    #     [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # # test_envs = gym.make(args.task)
    # test_envs = DummyVectorEnv(
    #     [lambda: gym.make(args.task) for _ in range(args.test_num)])
    if args.data_quantity != 0:
        env.set_data_quantity(args.data_quantity)
    if args.data_quality != 0:
        env.set_data_quality(args.data_quality)
    if args.psi != 0:
        env.set_psi(args.psi)
    if args.nu != 0:
        env.set_nu(args.nu)
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # train_envs.seed(args.seed)
    # test_envs.seed(args.seed)
    # model

    # server policy
    server_policy = build_policy(0, args)

    # client policy
    ND_policy = build_policy(1, args)
    RD_policy = build_policy(2, args)
    FD_policy = build_policy(3, args)
    # use replay buffers directly instead of a Collector
    server_buffer = ReplayBuffer(args.buffer_size)
    ND_buffer = ReplayBuffer(args.buffer_size)
    RD_buffer = ReplayBuffer(args.buffer_size)
    FD_buffer = ReplayBuffer(args.buffer_size)

    # log
    log_path = os.path.join(args.logdir, args.task, 'ppo')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    # hand-written trainer and tester start here
    # to check the server's convergence, we first do not train the client networks
    start_time = time.time()
    _server_obs, _ND_obs, _RD_obs, _FD_obs = env.reset()
    _server_act = _server_rew = _done = _info = None
    server_buffer.reset()
    _ND_act = _ND_rew = _RD_act = _RD_rew = _FD_act = _FD_rew = [None]
    ND_buffer.reset()
    RD_buffer.reset()
    FD_buffer.reset()
    all_server_costs = []
    all_ND_utility = []
    all_RD_utility = []
    all_FD_utility = []
    all_leak_probability = []
    for epoch in range(1, 1 + args.epoch):
        # each epoch collects N*T transitions, then trains M times with batch size B
        server_costs = []
        ND_utility = []
        FD_utility = []
        RD_utility = []
        leak_probability = []
        payment = []
        expected_time = []
        training_time = []
        with tqdm.tqdm(total=args.step_per_epoch,
                       desc=f'Epoch #{epoch}',
                       **tqdm_config) as t:
            while t.n < t.total:
                # collect data without tracking gradients
                # server
                _server_obs, _ND_obs, _RD_obs, _FD_obs = env.reset()
                server_batch = Batch(obs=_server_obs,
                                     act=_server_act,
                                     rew=_server_rew,
                                     done=_done,
                                     obs_next=None,
                                     info=_info,
                                     policy=None)
                with torch.no_grad():
                    server_result = server_policy(server_batch, None)
                _server_policy = [{}]
                _server_act = to_numpy(server_result.act)
                # ND
                ND_batch = Batch(obs=_ND_obs,
                                 act=_ND_act,
                                 rew=_ND_rew,
                                 done=_done,
                                 obs_next=None,
                                 info=_info,
                                 policy=None)
                with torch.no_grad():
                    ND_result = ND_policy(ND_batch, None)
                _ND_policy = [{}]
                _ND_act = to_numpy(ND_result.act)
                # RD
                RD_batch = Batch(obs=_RD_obs,
                                 act=_RD_act,
                                 rew=_RD_rew,
                                 done=_done,
                                 obs_next=None,
                                 info=_info,
                                 policy=None)
                with torch.no_grad():
                    RD_result = RD_policy(RD_batch, None)
                _RD_policy = [{}]
                _RD_act = to_numpy(RD_result.act)
                # FD
                FD_batch = Batch(obs=_FD_obs,
                                 act=_FD_act,
                                 rew=_FD_rew,
                                 done=_done,
                                 obs_next=None,
                                 info=_info,
                                 policy=None)
                with torch.no_grad():
                    FD_result = FD_policy(FD_batch, None)
                _FD_policy = [{}]
                _FD_act = to_numpy(FD_result.act)
                # print(_ND_act.shape)
                (server_obs_next, ND_obs_next, RD_obs_next, FD_obs_next,
                 _server_rew, _client_rew, _done, _info) = env.step(
                     _server_act[0], _ND_act[0], _RD_act[0], _FD_act[0])

                server_costs.append(_server_rew)
                ND_utility.append(_client_rew[0])
                RD_utility.append(_client_rew[1])
                FD_utility.append(_client_rew[2])
                leak_probability.append(_info[0]["leak"])
                payment.append(env.payment)
                expected_time.append(env.expected_time)
                training_time.append(env.global_time * env.time_lambda)
                # add the transitions to the replay buffers
                server_buffer.add(
                    Batch(obs=_server_obs[0],
                          act=_server_act[0],
                          rew=_server_rew[0],
                          done=_done[0],
                          obs_next=server_obs_next[0],
                          info=_info[0],
                          policy=_server_policy[0]))
                ND_buffer.add(
                    Batch(obs=_ND_obs[0],
                          act=_ND_act[0],
                          rew=_client_rew[0],
                          done=_done[0],
                          obs_next=ND_obs_next[0],
                          info=_info[0],
                          policy=_ND_policy[0]))
                RD_buffer.add(
                    Batch(obs=_RD_obs[0],
                          act=_RD_act[0],
                          rew=_client_rew[1],
                          done=_done[0],
                          obs_next=RD_obs_next[0],
                          info=_info[0],
                          policy=_RD_policy[0]))
                FD_buffer.add(
                    Batch(obs=_FD_obs[0],
                          act=_FD_act[0],
                          rew=_client_rew[2],
                          done=_done[0],
                          obs_next=FD_obs_next[0],
                          info=_info[0],
                          policy=_FD_policy[0]))
                t.update(1)
                _server_obs = server_obs_next
                _ND_obs = ND_obs_next
                _RD_obs = RD_obs_next
                _FD_obs = FD_obs_next
        all_server_costs.append(np.array(server_costs).mean())
        all_ND_utility.append(np.array(ND_utility).mean())
        all_RD_utility.append(np.array(RD_utility).mean())
        all_FD_utility.append(np.array(FD_utility).mean())
        all_leak_probability.append(np.array(leak_probability).mean())
        print("current bandwidth:", env.bandwidth)
        print("leak signal:", env.leak_NU, env.leak_FU)
        print("current server cost:", np.array(server_costs).mean())
        print("current device utility:", all_ND_utility[-1],
              all_RD_utility[-1], all_FD_utility[-1])
        print("leak probability:", all_leak_probability[-1])
        print("server_act:", _server_act[0])
        print("device_acts:", _ND_act[0], _RD_act[0], _FD_act[0])
        print("payment cost:", np.array(payment).mean())
        print("Expected time cost:", np.array(expected_time).mean())
        print("Training time cost:", np.array(training_time).mean())
        # print("server_act:",_server_act)
        # print("client_act:",_client_act)
        print("info:", env.communication_time, env.computation_time,
              env.K_theta)
        server_batch_data, server_indice = server_buffer.sample(0)
        server_batch_data = server_policy.process_fn(server_batch_data,
                                                     server_buffer,
                                                     server_indice)
        server_policy.learn(server_batch_data, args.batch_size,
                            args.repeat_per_collect)
        server_buffer.reset()

        ND_batch_data, ND_indice = ND_buffer.sample(0)
        ND_batch_data = ND_policy.process_fn(ND_batch_data, ND_buffer,
                                             ND_indice)
        ND_policy.learn(ND_batch_data, args.batch_size,
                        args.repeat_per_collect)
        ND_buffer.reset()

        RD_batch_data, RD_indice = RD_buffer.sample(0)
        RD_batch_data = RD_policy.process_fn(RD_batch_data, RD_buffer,
                                             RD_indice)
        RD_policy.learn(RD_batch_data, args.batch_size,
                        args.repeat_per_collect)
        RD_buffer.reset()

        FD_batch_data, FD_indice = FD_buffer.sample(0)
        FD_batch_data = FD_policy.process_fn(FD_batch_data, FD_buffer,
                                             FD_indice)
        FD_policy.learn(FD_batch_data, args.batch_size,
                        args.repeat_per_collect)
        FD_buffer.reset()
    print("all_server_cost:", all_server_costs)
    print("all_ND_utility:", all_ND_utility)
    print("all_RD_utility:", all_RD_utility)
    print("all_FD_utility:", all_FD_utility)
    print("all_leak_probability:", all_leak_probability)
    plt.plot(all_server_costs)
    plt.show()
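

# --- hedged sketch (not part of the original MAPPO.py) ----------------------
# test_Fedppo builds a SummaryWriter and a save_fn but the loop above never
# calls them.  One minimal way to persist the per-epoch curves and the server
# policy would be a helper like this; the function name and the scalar tags
# are illustrative assumptions, not the project's API.
from torch.utils.tensorboard import SummaryWriter


def log_curves(writer: SummaryWriter, server_costs, leak_probabilities):
    """Write one scalar per epoch so the curves show up in TensorBoard."""
    for epoch, (cost, leak) in enumerate(
            zip(server_costs, leak_probabilities), start=1):
        writer.add_scalar('train/server_cost', cost, global_step=epoch)
        writer.add_scalar('train/leak_probability', leak, global_step=epoch)
    writer.flush()

# Inside test_Fedppo this could be called right before plt.show(), e.g.
#   log_curves(writer, all_server_costs, all_leak_probability)
#   save_fn(server_policy)  # writes policy.pth under log_path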
Example #18
0
def test_replaybuffer(size=10, bufsize=20):
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize)
    buf.update(buf)
    assert str(buf) == buf.__class__.__name__ + '()'
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        buf.add(
            Batch(obs=obs,
                  act=[a],
                  rew=rew,
                  done=done,
                  obs_next=obs_next,
                  info=info))
        obs = obs_next
        assert len(buf) == min(bufsize, i + 1)
    assert buf.act.dtype == int
    assert buf.act.shape == (bufsize, 1)
    data, indices = buf.sample(bufsize * 2)
    assert (indices < len(buf)).all()
    assert (data.obs < size).all()
    assert (0 <= data.done).all() and (data.done <= 1).all()
    b = ReplayBuffer(size=10)
    # a negative batch size should return an empty index array
    assert b.sample_indices(-1).tolist() == []
    ptr, ep_rew, ep_len, ep_idx = b.add(
        Batch(obs=1,
              act=1,
              rew=1,
              done=1,
              obs_next='str',
              info={
                  'a': 3,
                  'b': {
                      'c': 5.0
                  }
              }))
    assert b.obs[0] == 1
    assert b.done[0]
    assert b.obs_next[0] == 'str'
    assert np.all(b.obs[1:] == 0)
    assert np.all(b.obs_next[1:] == np.array(None))
    assert b.info.a[0] == 3 and b.info.a.dtype == int
    assert np.all(b.info.a[1:] == 0)
    assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == float
    assert np.all(b.info.b.c[1:] == 0.0)
    assert ptr.shape == (1, ) and ptr[0] == 0
    assert ep_rew.shape == (1, ) and ep_rew[0] == 1
    assert ep_len.shape == (1, ) and ep_len[0] == 1
    assert ep_idx.shape == (1, ) and ep_idx[0] == 0
    # test that newly appearing info keys are handled dynamically by the buffer
    batch = Batch(obs=2,
                  act=2,
                  rew=2,
                  done=0,
                  obs_next="str2",
                  info={
                      "a": 4,
                      "d": {
                          "e": -np.inf
                      }
                  })
    b.add(batch)
    info_keys = ["a", "b", "d"]
    assert set(b.info.keys()) == set(info_keys)
    assert b.info.a[1] == 4 and b.info.b.c[1] == 0
    assert b.info.d.e[1] == -np.inf
    # test batch-style adding method, where len(batch) == 1
    batch.done = [1]
    batch.info.e = np.zeros([1, 4])
    batch = Batch.stack([batch])
    ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0])
    assert ptr.shape == (1, ) and ptr[0] == 2
    assert ep_rew.shape == (1, ) and ep_rew[0] == 4
    assert ep_len.shape == (1, ) and ep_len[0] == 2
    assert ep_idx.shape == (1, ) and ep_idx[0] == 1
    assert set(b.info.keys()) == set(info_keys + ["e"])
    assert b.info.e.shape == (b.maxsize, 1, 4)
    with pytest.raises(IndexError):
        b[22]
    # test prev / next
    assert np.all(b.prev(np.array([0, 1, 2])) == [0, 1, 1])
    assert np.all(b.next(np.array([0, 1, 2])) == [0, 2, 2])
    batch.done = [0]
    b.add(batch, buffer_ids=[0])
    assert np.all(b.prev(np.array([0, 1, 2, 3])) == [0, 1, 1, 3])
    assert np.all(b.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3])
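

# --- hedged sketch (helper name is ours, not Tianshou's API) -----------------
# The prev()/next() asserts above rely on the fact that an episode-start index
# maps to itself under prev().  A tiny helper built on that property:
import numpy as np
from tianshou.data import ReplayBuffer


def episode_start(buf: ReplayBuffer, index: int) -> int:
    """Walk backwards with buf.prev() until the index stops moving."""
    idx = np.array([index])
    prev = buf.prev(idx)
    while prev[0] != idx[0]:
        idx, prev = prev, buf.prev(prev)
    return int(idx[0])

# With the buffer `b` from the test above, episode_start(b, 2) returns 1,
# matching the assert that b.prev(np.array([0, 1, 2])) equals [0, 1, 1].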