def test_replaybuffer(size=10, bufsize=20):
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize)
    buf.update(buf)
    assert str(buf) == buf.__class__.__name__ + '()'
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        buf.add(obs, [a], rew, done, obs_next, info)
        obs = obs_next
        assert len(buf) == min(bufsize, i + 1)
    with pytest.raises(ValueError):
        buf._add_to_buffer('rew', np.array([1, 2, 3]))
    assert buf.act.dtype == object  # np.object is deprecated in numpy >= 1.20
    assert isinstance(buf.act[0], list)
    data, indice = buf.sample(bufsize * 2)
    assert (indice < len(buf)).all()
    assert (data.obs < size).all()
    assert (0 <= data.done).all() and (data.done <= 1).all()
    b = ReplayBuffer(size=10)
    b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}})
    assert b.obs[0] == 1
    assert b.done[0] == 'str'
    assert np.all(b.obs[1:] == 0)
    assert np.all(b.done[1:] == np.array(None))
    assert b.info.a[0] == 3 and b.info.a.dtype == np.integer
    assert np.all(b.info.a[1:] == 0)
    assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact
    assert np.all(b.info.b.c[1:] == 0.0)
    with pytest.raises(IndexError):
        b[22]
    b = ListReplayBuffer()
    with pytest.raises(NotImplementedError):
        b.sample(0)
def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3):
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize, stack_num=stack_num)
    buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
    buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True)
    obs = env.reset(1)
    for _ in range(16):
        obs_next, rew, done, info = env.step(1)
        buf.add(Batch(obs=obs, act=1, rew=rew, done=done, info=info))
        buf2.add(Batch(obs=obs, act=1, rew=rew, done=done, info=info))
        buf3.add(Batch(obs=[obs, obs, obs], act=1, rew=rew, done=done,
                       obs_next=[obs, obs], info=info))
        obs = obs_next
        if done:
            obs = env.reset(1)
    indices = np.arange(len(buf))
    assert np.allclose(buf.get(indices, 'obs')[..., 0], [
        [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
        [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
        [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]])
    assert np.allclose(buf.get(indices, 'obs'), buf3.get(indices, 'obs'))
    assert np.allclose(buf.get(indices, 'obs'), buf3.get(indices, 'obs_next'))
    _, indices = buf2.sample(0)
    assert indices.tolist() == [2, 6]
    _, indices = buf2.sample(1)
    assert indices[0] in [2, 6]
    batch, indices = buf2.sample(-1)  # negative batch size -> no data
    assert indices.tolist() == [] and len(batch) == 0
    with pytest.raises(IndexError):
        buf[bufsize * 2]
def test_stack(size=5, bufsize=9, stack_num=4):
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize, stack_num=stack_num)
    buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
    buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True)
    obs = env.reset(1)
    for i in range(16):
        obs_next, rew, done, info = env.step(1)
        buf.add(obs, 1, rew, done, None, info)
        buf2.add(obs, 1, rew, done, None, info)
        buf3.add([None, None, obs], 1, rew, done, [None, obs], info)
        obs = obs_next
        if done:
            obs = env.reset(1)
    indice = np.arange(len(buf))
    assert np.allclose(buf.get(indice, 'obs')[..., 0], [
        [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
        [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
        [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]])
    assert np.allclose(buf.get(indice, 'obs'), buf3.get(indice, 'obs'))
    assert np.allclose(buf.get(indice, 'obs'), buf3.get(indice, 'obs_next'))
    _, indice = buf2.sample(0)
    assert indice.tolist() == [2, 6]
    _, indice = buf2.sample(1)
    assert indice[0] in [2, 6]
    with pytest.raises(IndexError):
        buf[bufsize * 2]
def test_nstep_returns(size=10000):
    buf = ReplayBuffer(10)
    for i in range(12):
        buf.add(obs=0, act=0, rew=i + 1, done=i % 4 == 3)
    batch, indice = buf.sample(0)
    assert np.allclose(indice, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1])
    # buffer content after wraparound (rew was added as i + 1):
    # rew:  [11, 12, 3, 4, 5, 6, 7, 8, 9, 10]
    # done: [ 0,  1, 0, 1, 0, 0, 0, 1, 0,  0]
    # test nstep = 1
    returns = to_numpy(BasePolicy.compute_nstep_return(
        batch, buf, indice, target_q_fn, gamma=.1, n_step=1).pop('returns'))
    assert np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12])
    r_ = compute_nstep_return_base(1, .1, buf, indice)
    assert np.allclose(returns, r_), (r_, returns)
    returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
        batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=1
    ).pop('returns'))
    assert np.allclose(returns_multidim, returns[:, np.newaxis])
    # test nstep = 2
    returns = to_numpy(BasePolicy.compute_nstep_return(
        batch, buf, indice, target_q_fn, gamma=.1, n_step=2).pop('returns'))
    assert np.allclose(returns,
                       [3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12])
    r_ = compute_nstep_return_base(2, .1, buf, indice)
    assert np.allclose(returns, r_)
    returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
        batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=2
    ).pop('returns'))
    assert np.allclose(returns_multidim, returns[:, np.newaxis])
    # test nstep = 10
    returns = to_numpy(BasePolicy.compute_nstep_return(
        batch, buf, indice, target_q_fn, gamma=.1, n_step=10).pop('returns'))
    assert np.allclose(returns,
                       [3.4, 4, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12])
    r_ = compute_nstep_return_base(10, .1, buf, indice)
    assert np.allclose(returns, r_)
    returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
        batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=10
    ).pop('returns'))
    assert np.allclose(returns_multidim, returns[:, np.newaxis])

    if __name__ == '__main__':
        buf = ReplayBuffer(size)
        for i in range(int(size * 1.5)):
            buf.add(obs=0, act=0, rew=i + 1, done=np.random.randint(3) == 0)
        batch, indice = buf.sample(256)

        def vanilla():
            return compute_nstep_return_base(3, .1, buf, indice)

        def optimized():
            return BasePolicy.compute_nstep_return(
                batch, buf, indice, target_q_fn, gamma=.1, n_step=3)

        cnt = 3000
        print('nstep vanilla', timeit(vanilla, setup=vanilla, number=cnt))
        print('nstep optim  ', timeit(optimized, setup=optimized, number=cnt))
def test_nstep_returns():
    buf = ReplayBuffer(10)
    for i in range(12):
        buf.add(obs=0, act=0, rew=i + 1, done=i % 4 == 3)
    batch, indice = buf.sample(0)
    assert np.allclose(indice, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1])
    # buffer content after wraparound (rew was added as i + 1):
    # rew:  [11, 12, 3, 4, 5, 6, 7, 8, 9, 10]
    # done: [ 0,  1, 0, 1, 0, 0, 0, 1, 0,  0]
    # test nstep = 1
    returns = BasePolicy.compute_nstep_return(
        batch, buf, indice, target_q_fn, gamma=.1, n_step=1).pop('returns')
    assert np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12])
    # test nstep = 2
    returns = BasePolicy.compute_nstep_return(
        batch, buf, indice, target_q_fn, gamma=.1, n_step=2).pop('returns')
    assert np.allclose(returns,
                       [3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12])
    # test nstep = 10
    returns = BasePolicy.compute_nstep_return(
        batch, buf, indice, target_q_fn, gamma=.1, n_step=10).pop('returns')
    assert np.allclose(returns,
                       [3.4, 4, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12])
def test_replaybuffer(size=10, bufsize=20):
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize)
    buf2 = ReplayBuffer(bufsize)
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        buf.add(obs, a, rew, done, obs_next, info)
        obs = obs_next
        assert len(buf) == min(bufsize, i + 1)
    data, indice = buf.sample(bufsize * 2)
    assert (indice < len(buf)).all()
    assert (data.obs < size).all()
    assert (0 <= data.done).all() and (data.done <= 1).all()
    assert len(buf) > len(buf2)
    buf2.update(buf)
    assert len(buf) == len(buf2)
    assert buf2[0].obs == buf[5].obs
    assert buf2[-1].obs == buf[4].obs
    b = ReplayBuffer(size=10)
    b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}})
    assert b.obs[0] == 1
    assert b.done[0] == 'str'
    assert np.all(b.obs[1:] == 0)
    assert np.all(b.done[1:] == np.array(None))
    assert b.info.a[0] == 3 and b.info.a.dtype == np.integer
    assert np.all(b.info.a[1:] == 0)
    assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact
    assert np.all(b.info.b.c[1:] == 0.0)
def test_nstep_returns(size=10000):
    buf = ReplayBuffer(10)
    for i in range(12):
        buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % 4 == 3))
    batch, indices = buf.sample(0)
    assert np.allclose(indices, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1])
    # rew:  [11, 12, 3, 4, 5, 6, 7, 8, 9, 10]
    # done: [ 0,  1, 0, 1, 0, 0, 0, 1, 0,  0]
    # test nstep = 1
    returns = to_numpy(BasePolicy.compute_nstep_return(
        batch, buf, indices, target_q_fn, gamma=.1, n_step=1
    ).pop('returns').reshape(-1))
    assert np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12])
    r_ = compute_nstep_return_base(1, .1, buf, indices)
    assert np.allclose(returns, r_), (r_, returns)
    returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
        batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=1
    ).pop('returns'))
    assert np.allclose(returns_multidim, returns[:, np.newaxis])
    # test nstep = 2
    returns = to_numpy(BasePolicy.compute_nstep_return(
        batch, buf, indices, target_q_fn, gamma=.1, n_step=2
    ).pop('returns').reshape(-1))
    assert np.allclose(returns,
                       [3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12])
    r_ = compute_nstep_return_base(2, .1, buf, indices)
    assert np.allclose(returns, r_)
    returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
        batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=2
    ).pop('returns'))
    assert np.allclose(returns_multidim, returns[:, np.newaxis])
    # test nstep = 10
    returns = to_numpy(BasePolicy.compute_nstep_return(
        batch, buf, indices, target_q_fn, gamma=.1, n_step=10
    ).pop('returns').reshape(-1))
    assert np.allclose(returns,
                       [3.4, 4, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12])
    r_ = compute_nstep_return_base(10, .1, buf, indices)
    assert np.allclose(returns, r_)
    returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
        batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=10
    ).pop('returns'))
    assert np.allclose(returns_multidim, returns[:, np.newaxis])
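# The n-step tests above call target_q_fn, target_q_fn_multidim, and
# compute_nstep_return_base without defining them. Below is a minimal sketch
# that reproduces the expected values — a hypothetical reconstruction, not
# necessarily the authors' exact helpers. The mocked target Q-value is the
# negated reward of the transition following `indice`, and the baseline
# computes the n-step return with a plain double loop.
def target_q_fn(buffer, indice):
    # Q(s_{t+n}) mocked as -rew of the next transition; when the episode
    # ended at `indice`, stay in place instead of crossing the boundary
    indice = (indice + 1 - buffer.done[indice]) % len(buffer)
    return torch.tensor(-buffer.rew[indice], dtype=torch.float32)


def target_q_fn_multidim(buffer, indice):
    # same mock target, with an extra trailing dimension of size 1
    return target_q_fn(buffer, indice).unsqueeze(1)


def compute_nstep_return_base(nstep, gamma, buffer, indice):
    # reference implementation: O(len(indice) * nstep), no vectorization
    returns = np.zeros_like(indice, dtype=float)
    buf_len = len(buffer)
    for i in range(len(indice)):
        done_within_n, r = False, 0.0
        real_n = nstep
        for n in range(nstep):
            idx = (indice[i] + n) % buf_len
            r += buffer.rew[idx] * gamma ** n
            if buffer.done[idx]:
                done_within_n = True
                real_n = n + 1
                break
        if not done_within_n:
            # bootstrap from the mocked target Q at the last reached step
            idx = (indice[i] + real_n - 1) % buf_len
            r += to_numpy(target_q_fn(buffer, idx)) * gamma ** real_n
        returns[i] = r
    return returns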
def update(self, batch_size: int, buffer: ReplayBuffer, *args, **kwargs):
    """Update the policy network and replay buffer (if needed).

    It consists of three steps: process_fn, learn, and post_process_fn.

    :param int batch_size: 0 means it will extract all the data from the
        buffer, otherwise it will sample a batch with the given batch_size.
    :param ReplayBuffer buffer: the corresponding replay buffer.
    """
    batch, indice = buffer.sample(batch_size)
    batch = self.process_fn(batch, buffer, indice)
    result = self.learn(batch, *args, **kwargs)
    self.post_process_fn(batch, buffer, indice)
    return result
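# A minimal usage sketch of update() above (the names `policy` and `buffer`
# are hypothetical: any BasePolicy subclass and a filled ReplayBuffer):
result = policy.update(64, buffer)  # sample 64 -> process_fn -> learn -> post_process_fn
result = policy.update(0, buffer)   # batch_size=0 extracts every stored transition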
def test_stack(size=5, bufsize=9, stack_num=4):
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize, stack_num=stack_num)
    buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
    obs = env.reset(1)
    for i in range(15):
        obs_next, rew, done, info = env.step(1)
        buf.add(obs, 1, rew, done, None, info)
        buf2.add(obs, 1, rew, done, None, info)
        obs = obs_next
        if done:
            obs = env.reset(1)
    indice = np.arange(len(buf))
    assert np.allclose(buf.get(indice, 'obs'), np.array([
        [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
        [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
        [3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]]))
    print(buf)
    _, indice = buf2.sample(0)
    assert indice == [2]
    _, indice = buf2.sample(1)
    assert indice.sum() == 2
def test_stack(size=5, bufsize=9, stack_num=4):
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize, stack_num=stack_num)
    buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
    obs = env.reset(1)
    for i in range(16):
        obs_next, rew, done, info = env.step(1)
        buf.add(obs, 1, rew, done, None, info)
        buf2.add(obs, 1, rew, done, None, info)
        obs = obs_next
        if done:
            obs = env.reset(1)
    indice = np.arange(len(buf))
    assert np.allclose(
        buf.get(indice, 'obs'),
        np.expand_dims([
            [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
            [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
            [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]], axis=-1))
    _, indice = buf2.sample(0)
    assert indice.tolist() == [2, 6]
    _, indice = buf2.sample(1)
    assert indice[0] in [2, 6]
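# What buf.get(indice, 'obs') with stack_num=4 computes, roughly: for each
# index it walks at most stack_num - 1 steps backwards without crossing an
# episode boundary, repeating the earliest observation when the episode is
# shorter. A simplified standalone sketch under that assumption (it ignores
# some ring-buffer edge cases the real implementation handles):
def get_stacked_obs(obs, done, index, stack_num=4):
    stack = []
    for _ in range(stack_num):
        stack.insert(0, obs[index])        # prepend to keep time order
        prev = (index - 1) % len(obs)
        if not done[prev]:                 # do not cross episode boundaries
            index = prev
    return np.array(stack)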
def test_ReplayBuffer():
    """Demo of tianshou.data.ReplayBuffer:

    buf.add(), buf.get(), buf.update(), buf.sample(), buf.reset(), len(buf)
    """
    buf1 = ReplayBuffer(size=15)
    for i in range(3):
        buf1.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={},
                 weight=None)
    print(len(buf1))
    print(buf1.obs)
    buf2 = ReplayBuffer(size=10)
    for i in range(15):
        buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={},
                 weight=None)
    print(buf2.obs)
    buf1.update(buf2)
    print(buf1.obs)
    index = [1, 3, 5]
    # `key` is a required argument
    print(buf2.get(index, key='obs'))
    print('--------------------')
    sample_data, indice = buf2.sample(batch_size=4)
    print(sample_data, indice)
    print(sample_data.obs == buf2[indice].obs)
    print('--------------------')
    # buf.reset() only resets the index, not the content
    print(len(buf2))
    buf2.reset()
    print(len(buf2))
    print(buf2)
    print('--------------------')
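# A sketch of the wraparound behavior the demo above prints (assuming the
# standard ring-buffer semantics of this ReplayBuffer): after 15 adds into a
# size-10 buffer, the write pointer has wrapped once, so the oldest five
# entries (obs 0-4) are overwritten by the last five adds (obs 10-14).
buf = ReplayBuffer(size=10)
for i in range(15):
    buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
assert np.allclose(buf.obs, [10, 11, 12, 13, 14, 5, 6, 7, 8, 9])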
def test_replaybuffer(size=10, bufsize=20):
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize)
    buf2 = ReplayBuffer(bufsize)
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        buf.add(obs, a, rew, done, obs_next, info)
        obs = obs_next
        assert len(buf) == min(bufsize, i + 1), (len(buf), i)
    data, indice = buf.sample(bufsize * 2)
    assert (indice < len(buf)).all()
    assert (data.obs < size).all()
    assert (0 <= data.done).all() and (data.done <= 1).all()
    assert len(buf) > len(buf2)
    buf2.update(buf)
    assert len(buf) == len(buf2)
    assert buf2[0].obs == buf[5].obs
    assert buf2[-1].obs == buf[4].obs
class SSACPolicy(DDPGPolicy):
    """Implementation of Simulator-based Soft Actor-Critic.

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
    :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a))
    :param torch.optim.Optimizer critic1_optim: the optimizer for the first
        critic network.
    :param torch.nn.Module critic2: the second critic network.
        (s, a -> Q(s, a))
    :param torch.optim.Optimizer critic2_optim: the optimizer for the second
        critic network.
    :param action_range: the action range (minimum, maximum).
    :type action_range: Tuple[float, float]
    :param float tau: param for soft update of the target network, defaults
        to 0.005.
    :param float gamma: discount factor, in [0, 1], defaults to 0.99.
    :param alpha: entropy regularization coefficient, defaults to 0.2. If a
        tuple (target_entropy, log_alpha, alpha_optim) is provided, then
        alpha is automatically tuned.
    :type alpha: float or Tuple[float, torch.Tensor, torch.optim.Optimizer]
    :param bool reward_normalization: normalize the reward to Normal(0, 1),
        defaults to False.
    :param bool ignore_done: ignore the done flag while training the policy,
        defaults to False.
    :param BaseNoise exploration_noise: add noise to the action for
        exploration, defaults to None. This is useful when solving
        hard-exploration problems.
    :param bool deterministic_eval: whether to use the deterministic action
        (mean of the Gaussian policy) instead of a stochastic action sampled
        from the policy, defaults to True.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for a more
        detailed explanation.
    """

    def __init__(
        self,
        actor: torch.nn.Module,
        actor_optim: torch.optim.Optimizer,
        critic1: torch.nn.Module,
        critic1_optim: torch.optim.Optimizer,
        critic2: torch.nn.Module,
        critic2_optim: torch.optim.Optimizer,
        simulator: Optional[torch.nn.Module],
        args,
        action_range: Tuple[float, float],
        tau: float = 0.005,
        gamma: float = 0.99,
        alpha: Union[float, Tuple[float, torch.Tensor,
                                  torch.optim.Optimizer]] = 0.2,
        reward_normalization: bool = False,
        ignore_done: bool = False,
        estimation_step: int = 1,
        exploration_noise: Optional[BaseNoise] = None,
        deterministic_eval: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(None, None, None, None, action_range, tau, gamma,
                         exploration_noise, reward_normalization,
                         ignore_done, estimation_step, **kwargs)
        if simulator is not None:
            self.simulator = simulator
        self.args = args
        self.simulation_env = None
        self.loss_history = []
        self.gbm_model = None
        self.update_step = self.args.max_update_step
        self.simulator_buffer = ReplayBuffer(size=self.args.buffer_size)
        self.actor, self.actor_optim = actor, actor_optim
        self.critic1, self.critic1_old = critic1, deepcopy(critic1)
        self.critic1_old.eval()
        self.critic1_optim = critic1_optim
        self.critic2, self.critic2_old = critic2, deepcopy(critic2)
        self.critic2_old.eval()
        self.critic2_optim = critic2_optim
        self.start_simulation = False
        self._is_auto_alpha = False
        self._alpha: Union[float, torch.Tensor]
        if isinstance(alpha, tuple):
            self._is_auto_alpha = True
            self._target_entropy, self._log_alpha, self._alpha_optim = alpha
            assert alpha[1].shape == torch.Size([1]) and alpha[1].requires_grad
            self._alpha = self._log_alpha.detach().exp()
        else:
            self._alpha = alpha
        self._deterministic_eval = deterministic_eval
        self.__eps = np.finfo(np.float32).eps.item()

    def train(self, mode: bool = True) -> "SSACPolicy":
        self.training = mode
        self.actor.train(mode)
        self.critic1.train(mode)
        self.critic2.train(mode)
        return self

    def sync_weight(self) -> None:
        for o, n in zip(self.critic1_old.parameters(),
                        self.critic1.parameters()):
            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
        for o, n in zip(self.critic2_old.parameters(),
                        self.critic2.parameters()):
            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)

    def forward(  # type: ignore
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        obs = batch[input]
        logits, h = self.actor(obs, state=state, info=batch.info)
        assert isinstance(logits, tuple)
        dist = Independent(Normal(*logits), 1)
        if self._deterministic_eval and not self.training:
            x = logits[0]
        else:
            x = dist.rsample()
        y = torch.tanh(x)
        act = y * self._action_scale + self._action_bias
        y = self._action_scale * (1 - y.pow(2)) + self.__eps
        log_prob = dist.log_prob(x).unsqueeze(-1)
        log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
        if self._noise is not None and self.training and not self.updating:
            act += to_torch_as(self._noise(act.shape), act)
        act = act.clamp(self._range[0], self._range[1])
        return Batch(logits=logits, act=act, state=h, dist=dist,
                     log_prob=log_prob)

    def _target_q(self, buffer: ReplayBuffer,
                  indice: np.ndarray) -> torch.Tensor:
        batch = buffer[indice]  # batch.obs: s_{t+n}
        with torch.no_grad():
            obs_next_result = self(batch, input='obs_next')
            a_ = obs_next_result.act
            batch.act = to_torch_as(batch.act, a_)
            target_q = torch.min(
                self.critic1_old(batch.obs_next, a_),
                self.critic2_old(batch.obs_next, a_),
            ) - self._alpha * obs_next_result.log_prob
        return target_q

    def learn_batch(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        weight = batch.pop("weight", 1.0)
        # critic 1
        current_q1 = self.critic1(batch.obs, batch.act).flatten()
        target_q = batch.returns.flatten()
        td1 = current_q1 - target_q
        critic1_loss = (td1.pow(2) * weight).mean()
        # critic1_loss = F.mse_loss(current_q1, target_q)
        self.critic1_optim.zero_grad()
        critic1_loss.backward()
        self.critic1_optim.step()
        # critic 2
        current_q2 = self.critic2(batch.obs, batch.act).flatten()
        td2 = current_q2 - target_q
        critic2_loss = (td2.pow(2) * weight).mean()
        # critic2_loss = F.mse_loss(current_q2, target_q)
        self.critic2_optim.zero_grad()
        critic2_loss.backward()
        self.critic2_optim.step()
        batch.weight = (td1 + td2) / 2.0  # prio-buffer
        # actor
        obs_result = self(batch)
        a = obs_result.act
        current_q1a = self.critic1(batch.obs, a).flatten()
        current_q2a = self.critic2(batch.obs, a).flatten()
        actor_loss = (self._alpha * obs_result.log_prob.flatten()
                      - torch.min(current_q1a, current_q2a)).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()
        if self._is_auto_alpha:
            log_prob = obs_result.log_prob.detach() + self._target_entropy
            alpha_loss = -(self._log_alpha * log_prob).mean()
            self._alpha_optim.zero_grad()
            alpha_loss.backward()
            self._alpha_optim.step()
            self._alpha = self._log_alpha.detach().exp()
        self.sync_weight()
        result = {
            "la": actor_loss.item(),
            "lc": (critic1_loss.item() + critic2_loss.item()) / 2.0,
        }
        if self._is_auto_alpha:
            result["lal"] = alpha_loss.item()
            result["a"] = self._alpha.item()  # type: ignore
        return result

    def get_loss_batch(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        with torch.no_grad():
            weight = batch.pop("weight", 1.0)
            # critic 1
            current_q1 = self.critic1(batch.obs, batch.act).flatten()
            target_q = batch.returns.flatten()
            td1 = current_q1 - target_q
            critic1_loss = (td1.pow(2) * weight).mean()
            # critic1_loss = F.mse_loss(current_q1, target_q)
            # critic 2
            current_q2 = self.critic2(batch.obs, batch.act).flatten()
            td2 = current_q2 - target_q
            critic2_loss = (td2.pow(2) * weight).mean()
            # critic2_loss = F.mse_loss(current_q2, target_q)
            # batch.weight = (td1 + td2) / 2.0  # prio-buffer
            # actor
            obs_result = self(batch)
            a = obs_result.act
            current_q1a = self.critic1(batch.obs, a).flatten()
            current_q2a = self.critic2(batch.obs, a).flatten()
            actor_loss = (self._alpha * obs_result.log_prob.flatten()
                          - torch.min(current_q1a, current_q2a)).mean()
        if self._is_auto_alpha:
            # the alpha update needs gradients, so it runs outside no_grad
            log_prob = obs_result.log_prob.detach() + self._target_entropy
            alpha_loss = -(self._log_alpha * log_prob).mean()
            self._alpha_optim.zero_grad()
            alpha_loss.backward()
            self._alpha_optim.step()
            self._alpha = self._log_alpha.detach().exp()
        # self.sync_weight()
        result = {
            "la": actor_loss.item(),
            "lc": (critic1_loss.item() + critic2_loss.item()) / 2.0,
        }
        if self._is_auto_alpha:
            result["lal"] = alpha_loss.item()
            result["a"] = self._alpha.item()  # type: ignore
        return result

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        if self.update_step > 0:
            self.update_step -= 1
            batch.obs += self.args.noise_obs * np.random.randn(
                *np.shape(batch.obs))
            batch.rew += self.args.noise_rew * np.random.randn(
                *np.shape(batch.rew))
            simulator_loss = self.learn_simulator(batch)
            result = self.learn_batch(batch)
            result["lt"] = simulator_loss[0]
            result["lr"] = simulator_loss[1]
            # result["m"] = self.simulator.m
            # result["l"] = self.simulator.l
            # result["g"] = self.simulator.g
            # result["dt"] = self.simulator.dt
            self.loss_history.append([
                simulator_loss[0], simulator_loss[1],
                result["la"], result["lc"], 0, 0])
        else:
            if not self.start_simulation:
                kwargs['writer'].add_scalar(
                    'simulator/start_step', kwargs['env_step'],
                    global_step=kwargs['env_step'])
                self.start_simulation = True
            result = self.get_loss_batch(batch)
            if kwargs['i'] == 0 or \
                    self.simulator_buffer._size < self.args.batch_size:
                self.simulate_environment()
            simulation_batch, indice = self.simulator_buffer.sample(
                self.args.batch_size)
            simulation_batch = self.process_fn(
                simulation_batch, self.simulator_buffer, indice)
            simulator_result = self.learn_batch(simulation_batch)
            self.post_process_fn(
                simulation_batch, self.simulator_buffer, indice)
            result["la2"] = simulator_result["la"]
            result["lc2"] = simulator_result["lc"]
            self.loss_history.append([
                0, 0, result["la"], result["lc"],
                result["la2"], result["lc2"]])
        return result

    def simulate_environment(self):
        self.simulation_env = SimulationEnv(self.args, self.simulator)
        obs, act, rew, done, info = [], [], [], [], []
        obs.append(self.simulation_env.reset())
        for i in range(self.args.n_simulator_step):
            with torch.no_grad():
                act.append(self(Batch(obs=obs[-1], info={})).act.cpu().numpy())
            result = self.simulation_env.step(act[-1])
            obs.append(result[0])
            rew.append(result[1])
            done.append(result[2])
            info.append(result[3])
        obs_next = np.array(obs[1:])
        obs = np.array(obs[:-1])
        act = np.array(act)
        rew = np.array(rew)
        done = np.array(done)
        for j in range(obs.shape[1]):
            for i in range(self.args.n_simulator_step):
                self.simulator_buffer.add(obs[i, j], act[i, j], rew[i, j],
                                          done[i, j], obs_next[i, j])
        return None

    def learn_simulator(self, batch: Batch):
        target_obs = torch.tensor(batch.obs_next).float().to(self.args.device)
        target_rew = torch.tensor(batch.rew).float().to(self.args.device)
        targets = [target_obs, target_rew]
        losses = self.simulator(batch.obs, batch.act,
                                white_box=self.args.white_box, train=True,
                                targets=targets, step=self.update_step)
        return losses[0], losses[1]
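# sync_weight() above is a soft (Polyak) target update:
# theta_target <- (1 - tau) * theta_target + tau * theta.
# A standalone sketch of the same operation, extracted for clarity:
def soft_update(target: torch.nn.Module, source: torch.nn.Module,
                tau: float) -> None:
    # copy a tau-weighted blend of source into target, in-place
    for t, s in zip(target.parameters(), source.parameters()):
        t.data.copy_(t.data * (1.0 - tau) + s.data * tau)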
class SDDPGPolicy(BasePolicy):
    """Implementation of Simulator-based Deep Deterministic Policy Gradient.

    We combine DDPG with a model-based simulator.

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
    :param torch.nn.Module critic: the critic network. (s, a -> Q(s, a))
    :param torch.optim.Optimizer critic_optim: the optimizer for critic
        network.
    :param torch.nn.Module simulator: the simulator network for the
        environment.
    :param argparse.Namespace args: the arguments.
    :param action_range: the action range (minimum, maximum).
    :type action_range: Tuple[float, float]
    :param float tau: param for soft update of the target network, defaults
        to 0.005.
    :param float gamma: discount factor, in [0, 1], defaults to 0.99.
    :param BaseNoise exploration_noise: the exploration noise, added to the
        action, defaults to ``GaussianNoise(sigma=0.1)``.
    :param bool reward_normalization: normalize the reward to Normal(0, 1),
        defaults to False.
    :param bool ignore_done: ignore the done flag while training the policy,
        defaults to False.
    :param int estimation_step: the number of steps to look ahead, greater
        than or equal to 1.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for a more
        detailed explanation.
    """

    def __init__(
        self,
        actor: Optional[torch.nn.Module],
        actor_optim: Optional[torch.optim.Optimizer],
        critic: Optional[torch.nn.Module],
        critic_optim: Optional[torch.optim.Optimizer],
        simulator: Optional[torch.nn.Module],
        args,
        action_range: Tuple[float, float],
        tau: float = 0.005,
        gamma: float = 0.99,
        exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
        reward_normalization: bool = False,
        ignore_done: bool = False,
        estimation_step: int = 1,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        if actor is not None and actor_optim is not None:
            self.actor: torch.nn.Module = actor
            self.actor_old = deepcopy(actor)
            self.actor_old.eval()
            self.actor_optim: torch.optim.Optimizer = actor_optim
        if critic is not None and critic_optim is not None:
            self.critic: torch.nn.Module = critic
            self.critic_old = deepcopy(critic)
            self.critic_old.eval()
            self.critic_optim: torch.optim.Optimizer = critic_optim
        if simulator is not None:
            self.simulator = simulator
        self.args = args
        self.simulation_env = None
        self.simulator_loss_threshold = self.args.simulator_loss_threshold
        self.base_env = gym.make(args.task)
        assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]"
        self._tau = tau
        assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]"
        self._gamma = gamma
        self._noise = exploration_noise
        self._range = action_range
        self._action_bias = (action_range[0] + action_range[1]) / 2.0
        self._action_scale = (action_range[1] - action_range[0]) / 2.0
        # using GaussianNoise makes only a small difference
        # self.noise = OUNoise()
        self._rm_done = ignore_done
        self._rew_norm = reward_normalization
        assert estimation_step > 0, "estimation_step should be greater than 0"
        self._n_step = estimation_step
        self.loss_history = []
        self.gbm_model = None
        self.update_step = self.args.max_update_step
        self.simulator_buffer = ReplayBuffer(size=self.args.buffer_size)

    def set_exp_noise(self, noise: Optional[BaseNoise]) -> None:
        """Set the exploration noise."""
        self._noise = noise

    def train(self, mode: bool = True) -> "SDDPGPolicy":
        """Set the module in training mode, except for the target network."""
        self.training = mode
        self.actor.train(mode)
        self.critic.train(mode)
        self.simulator.train(mode)
        return self

    def sync_weight(self) -> None:
        """Soft-update the weight for the target network."""
        for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
        for o, n in zip(self.critic_old.parameters(),
                        self.critic.parameters()):
            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)

    def _target_q(self, buffer: ReplayBuffer,
                  indice: np.ndarray) -> torch.Tensor:
        batch = buffer[indice]  # batch.obs_next: s_{t+n}
        with torch.no_grad():
            target_q = self.critic_old(
                batch.obs_next,
                self(batch, model='actor_old', input='obs_next').act)
        return target_q

    def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                   indice: np.ndarray) -> Batch:
        if self._rm_done:
            batch.done = batch.done * 0.0
        batch = self.compute_nstep_return(
            batch, buffer, indice, self._target_q, self._gamma, self._n_step,
            self._rew_norm)
        return batch

    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        model: str = "actor",
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        """Compute action over the given batch data.

        :return: A :class:`~tianshou.data.Batch` which has 2 keys:

            * ``act`` the action.
            * ``state`` the hidden state.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for a
            more detailed explanation.
        """
        model = getattr(self, model)
        obs = batch[input]
        actions, h = model(obs, state=state, info=batch.info)
        actions += self._action_bias
        if self._noise and not self.updating:
            actions += to_torch_as(self._noise(actions.shape), actions)
        actions = actions.clamp(self._range[0], self._range[1])
        return Batch(act=actions, state=h)

    def learn_batch(self, batch: Batch) -> Dict[str, float]:
        weight = batch.pop("weight", 1.0)
        current_q = self.critic(batch.obs, batch.act).flatten()
        target_q = batch.returns.flatten()
        td = current_q - target_q
        critic_loss = (td.pow(2) * weight).mean()
        batch.weight = td  # prio-buffer
        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()
        action = self(batch).act
        actor_loss = -self.critic(batch.obs, action).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()
        self.sync_weight()
        return {
            "la": actor_loss.item(),
            "lc": critic_loss.item(),
        }

    def get_loss_batch(self, batch: Batch) -> Dict[str, float]:
        weight = batch.pop("weight", 1.0)
        with torch.no_grad():
            current_q = self.critic(batch.obs, batch.act).flatten()
            target_q = batch.returns.flatten()
            td = current_q - target_q
            critic_loss = (td.pow(2) * weight).mean()
            action = self(batch).act
            actor_loss = -self.critic(batch.obs, action).mean()
        return {
            "la": actor_loss.item(),
            "lc": critic_loss.item(),
        }

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        if self.update_step > 0:
            self.update_step -= 1
            simulator_loss = self.learn_simulator(batch)
            result = self.learn_batch(batch)
            result["lt"] = simulator_loss[0]
            result["lr"] = simulator_loss[1]
            # result["m"] = self.simulator.m
            # result["l"] = self.simulator.l
            # result["g"] = self.simulator.g
            # result["dt"] = self.simulator.dt
            self.loss_history.append([
                simulator_loss[0], simulator_loss[1],
                result["la"], result["lc"], 0, 0])
        else:
            result = self.get_loss_batch(batch)
            if kwargs['i'] == 0 or \
                    self.simulator_buffer._size < self.args.batch_size:
                self.simulate_environment()
            simulation_batch, indice = self.simulator_buffer.sample(
                self.args.batch_size)
            simulation_batch = self.process_fn(
                simulation_batch, self.simulator_buffer, indice)
            simulator_result = self.learn_batch(simulation_batch)
            self.post_process_fn(
                simulation_batch, self.simulator_buffer, indice)
            result["la2"] = simulator_result["la"]
            result["lc2"] = simulator_result["lc"]
            self.loss_history.append([
                0, 0, result["la"], result["lc"],
                result["la2"], result["lc2"]])
        return result

    def simulate_environment(self):
        self.simulation_env = SimulationEnv(self.args, self.simulator)
        obs, act, rew, done, info = [], [], [], [], []
        obs.append(self.simulation_env.reset())
        for i in range(self.args.n_simulator_step):
            with torch.no_grad():
                act.append(self(Batch(obs=obs[-1], info={})).act.cpu().numpy())
            result = self.simulation_env.step(act[-1])
            obs.append(result[0])
            rew.append(result[1])
            done.append(result[2])
            info.append(result[3])
        obs_next = np.array(obs[1:])
        obs = np.array(obs[:-1])
        act = np.array(act)
        rew = np.array(rew)
        done = np.array(done)
        # obs = obs.reshape(-1, obs.shape[-1])
        # act = act.reshape(-1, act.shape[-1])
        # rew = np.array(rew).reshape(-1)
        # done = np.array(done).reshape(-1)
        # obs_next = obs_next.reshape(-1, obs_next.shape[-1])
        # rew = rew.reshape(obs.shape[0], obs.shape[1])
        for j in range(obs.shape[1]):
            for i in range(self.args.n_simulator_step):
                self.simulator_buffer.add(obs[i, j], act[i, j], rew[i, j],
                                          done[i, j], obs_next[i, j])
        return None

    def learn_simulator(self, batch: Batch):
        target_obs = torch.tensor(batch.obs_next).float().to(self.args.device)
        target_rew = torch.tensor(batch.rew).float().to(self.args.device)
        targets = [target_obs, target_rew]
        losses = self.simulator(batch.obs, batch.act,
                                white_box=self.args.white_box, train=True,
                                targets=targets, step=self.update_step)
        return losses[0], losses[1]
class Collector(object): """docstring for Collector""" def __init__(self, policy, env, buffer=None, stat_size=100): super().__init__() self.env = env self.env_num = 1 self.collect_step = 0 self.collect_episode = 0 self.collect_time = 0 if buffer is None: self.buffer = ReplayBuffer(100) else: self.buffer = buffer self.policy = policy self.process_fn = policy.process_fn self._multi_env = isinstance(env, BaseVectorEnv) self._multi_buf = False # True if buf is a list # need multiple cache buffers only if storing in one buffer self._cached_buf = [] if self._multi_env: self.env_num = len(env) if isinstance(self.buffer, list): assert len(self.buffer) == self.env_num, \ 'The number of data buffer does not match the number of ' \ 'input env.' self._multi_buf = True elif isinstance(self.buffer, ReplayBuffer): self._cached_buf = [ ListReplayBuffer() for _ in range(self.env_num) ] else: raise TypeError('The buffer in data collector is invalid!') self.reset_env() self.reset_buffer() # state over batch is either a list, an np.ndarray, or a torch.Tensor self.state = None self.step_speed = MovAvg(stat_size) self.episode_speed = MovAvg(stat_size) def reset_buffer(self): if self._multi_buf: for b in self.buffer: b.reset() else: self.buffer.reset() def get_env_num(self): return self.env_num def reset_env(self): self._obs = self.env.reset() self._act = self._rew = self._done = self._info = None if self._multi_env: self.reward = np.zeros(self.env_num) self.length = np.zeros(self.env_num) else: self.reward, self.length = 0, 0 for b in self._cached_buf: b.reset() def seed(self, seed=None): if hasattr(self.env, 'seed'): return self.env.seed(seed) def render(self, **kwargs): if hasattr(self.env, 'render'): return self.env.render(**kwargs) def close(self): if hasattr(self.env, 'close'): self.env.close() def _make_batch(self, data): if isinstance(data, np.ndarray): return data[None] else: return np.array([data]) def collect(self, n_step=0, n_episode=0, render=0): warning_count = 0 if not self._multi_env: n_episode = np.sum(n_episode) start_time = time.time() assert sum([(n_step != 0), (n_episode != 0)]) == 1, \ "One and only one collection number specification permitted!" cur_step = 0 cur_episode = np.zeros(self.env_num) if self._multi_env else 0 reward_sum = 0 length_sum = 0 while True: if warning_count >= 100000: warnings.warn( 'There are already many steps in an episode. 
' 'You should add a time limitation to your environment!', Warning) if self._multi_env: batch_data = Batch(obs=self._obs, act=self._act, rew=self._rew, done=self._done, obs_next=None, info=self._info) else: batch_data = Batch(obs=self._make_batch(self._obs), act=self._make_batch(self._act), rew=self._make_batch(self._rew), done=self._make_batch(self._done), obs_next=None, info=self._make_batch(self._info)) result = self.policy(batch_data, self.state) self.state = result.state if hasattr(result, 'state') else None if isinstance(result.act, torch.Tensor): self._act = result.act.detach().cpu().numpy() elif not isinstance(self._act, np.ndarray): self._act = np.array(result.act) else: self._act = result.act obs_next, self._rew, self._done, self._info = self.env.step( self._act if self._multi_env else self._act[0]) if render > 0: self.env.render() time.sleep(render) self.length += 1 self.reward += self._rew if self._multi_env: for i in range(self.env_num): data = { 'obs': self._obs[i], 'act': self._act[i], 'rew': self._rew[i], 'done': self._done[i], 'obs_next': obs_next[i], 'info': self._info[i] } if self._cached_buf: warning_count += 1 self._cached_buf[i].add(**data) elif self._multi_buf: warning_count += 1 self.buffer[i].add(**data) cur_step += 1 else: warning_count += 1 self.buffer.add(**data) cur_step += 1 if self._done[i]: if n_step != 0 or np.isscalar(n_episode) or \ cur_episode[i] < n_episode[i]: cur_episode[i] += 1 reward_sum += self.reward[i] length_sum += self.length[i] if self._cached_buf: cur_step += len(self._cached_buf[i]) self.buffer.update(self._cached_buf[i]) self.reward[i], self.length[i] = 0, 0 if self._cached_buf: self._cached_buf[i].reset() if isinstance(self.state, list): self.state[i] = None elif self.state is not None: if isinstance(self.state[i], dict): self.state[i] = {} else: self.state[i] = self.state[i] * 0 if isinstance(self.state, torch.Tensor): # remove ref count in pytorch (?) 
self.state = self.state.detach() if sum(self._done): obs_next = self.env.reset(np.where(self._done)[0]) if n_episode != 0: if isinstance(n_episode, list) and \ (cur_episode >= np.array(n_episode)).all() or \ np.isscalar(n_episode) and \ cur_episode.sum() >= n_episode: break else: self.buffer.add(self._obs, self._act[0], self._rew, self._done, obs_next, self._info) cur_step += 1 if self._done: cur_episode += 1 reward_sum += self.reward length_sum += self.length self.reward, self.length = 0, 0 self.state = None obs_next = self.env.reset() if n_episode != 0 and cur_episode >= n_episode: break if n_step != 0 and cur_step >= n_step: break self._obs = obs_next self._obs = obs_next if self._multi_env: cur_episode = sum(cur_episode) duration = time.time() - start_time self.step_speed.add(cur_step / duration) self.episode_speed.add(cur_episode / duration) self.collect_step += cur_step self.collect_episode += cur_episode self.collect_time += duration if isinstance(n_episode, list): n_episode = np.sum(n_episode) else: n_episode = max(cur_episode, 1) return { 'n/ep': cur_episode, 'n/st': cur_step, 'v/st': self.step_speed.get(), 'v/ep': self.episode_speed.get(), 'rew': reward_sum / n_episode, 'len': length_sum / n_episode, } def sample(self, batch_size): if self._multi_buf: if batch_size > 0: lens = [len(b) for b in self.buffer] total = sum(lens) batch_index = np.random.choice(total, batch_size, p=np.array(lens) / total) else: batch_index = np.array([]) batch_data = Batch() for i, b in enumerate(self.buffer): cur_batch = (batch_index == i).sum() if batch_size and cur_batch or batch_size <= 0: batch, indice = b.sample(cur_batch) batch = self.process_fn(batch, b, indice) batch_data.append(batch) else: batch_data, indice = self.buffer.sample(batch_size) batch_data = self.process_fn(batch_data, self.buffer, indice) return batch_data
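# The multi-buffer branch of sample() above allocates each of the batch_size
# slots to buffer i with probability len(buffer_i) / sum(lens), then asks
# each buffer for its slot count. A standalone sketch of that allocation
# step (the lengths below are hypothetical):
lens = [3, 7]  # hypothetical per-buffer lengths
batch_index = np.random.choice(
    len(lens), 16, p=np.array(lens) / sum(lens))   # one buffer id per slot
counts = [(batch_index == i).sum() for i in range(len(lens))]  # per-buffer quota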
class Collector(object): """The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param env: an environment or an instance of the :class:`~tianshou.env.BaseVectorEnv` class. :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class, or a list of :class:`~tianshou.data.ReplayBuffer`. If set to ``None``, it will automatically assign a small-size :class:`~tianshou.data.ReplayBuffer`. :param int stat_size: for the moving average of recording speed, defaults to 100. Example: :: policy = PGPolicy(...) # or other policies if you wish env = gym.make('CartPole-v0') replay_buffer = ReplayBuffer(size=10000) # here we set up a collector with a single environment collector = Collector(policy, env, buffer=replay_buffer) # the collector supports vectorized environments as well envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)]) buffers = [ReplayBuffer(size=5000) for _ in range(3)] # you can also pass a list of replay buffer to collector, for multi-env # collector = Collector(policy, envs, buffer=buffers) collector = Collector(policy, envs, buffer=replay_buffer) # collect at least 3 episodes collector.collect(n_episode=3) # collect 1 episode for the first env, 3 for the third env collector.collect(n_episode=[1, 0, 3]) # collect at least 2 steps collector.collect(n_step=2) # collect episodes with visual rendering (the render argument is the # sleep time between rendering consecutive frames) collector.collect(n_episode=1, render=0.03) # sample data with a given number of batch-size: batch_data = collector.sample(batch_size=64) # policy.learn(batch_data) # btw, vanilla policy gradient only # supports on-policy training, so here we pick all data in the buffer batch_data = collector.sample(batch_size=0) policy.learn(batch_data) # on-policy algorithms use the collected data only once, so here we # clear the buffer collector.reset_buffer() For the scenario of collecting data from multiple environments to a single buffer, the cache buffers will turn on automatically. It may return the data more than the given limitation. .. note:: Please make sure the given environment has a time limitation. """ def __init__(self, policy, env, buffer=None, stat_size=100, **kwargs): super().__init__() self.env = env self.env_num = 1 self.collect_step = 0 self.collect_episode = 0 self.collect_time = 0 if buffer is None: self.buffer = ReplayBuffer(100) else: self.buffer = buffer self.policy = policy self.process_fn = policy.process_fn self._multi_env = isinstance(env, BaseVectorEnv) self._multi_buf = False # True if buf is a list # need multiple cache buffers only if storing in one buffer self._cached_buf = [] if self._multi_env: self.env_num = len(env) if isinstance(self.buffer, list): assert len(self.buffer) == self.env_num, \ 'The number of data buffer does not match the number of ' \ 'input env.' 
self._multi_buf = True elif isinstance(self.buffer, ReplayBuffer): self._cached_buf = [ ListReplayBuffer() for _ in range(self.env_num)] else: raise TypeError('The buffer in data collector is invalid!') self.reset_env() self.reset_buffer() # state over batch is either a list, an np.ndarray, or a torch.Tensor self.state = None self.step_speed = MovAvg(stat_size) self.episode_speed = MovAvg(stat_size) def reset_buffer(self): """Reset the main data buffer.""" if self._multi_buf: for b in self.buffer: b.reset() else: self.buffer.reset() def get_env_num(self): """Return the number of environments the collector has.""" return self.env_num def reset_env(self): """Reset all of the environment(s)' states and reset all of the cache buffers (if need). """ self._obs = self.env.reset() self._act = self._rew = self._done = self._info = None if self._multi_env: self.reward = np.zeros(self.env_num) self.length = np.zeros(self.env_num) else: self.reward, self.length = 0, 0 for b in self._cached_buf: b.reset() def seed(self, seed=None): """Reset all the seed(s) of the given environment(s).""" if hasattr(self.env, 'seed'): return self.env.seed(seed) def render(self, **kwargs): """Render all the environment(s).""" if hasattr(self.env, 'render'): return self.env.render(**kwargs) def close(self): """Close the environment(s).""" if hasattr(self.env, 'close'): self.env.close() def _make_batch(self, data): """Return [data].""" if isinstance(data, np.ndarray): return data[None] else: return np.array([data]) def _reset_state(self, id): """Reset self.state[id].""" if self.state is None: return if isinstance(self.state, list): self.state[id] = None elif isinstance(self.state, dict): for k in self.state: if isinstance(self.state[k], list): self.state[k][id] = None elif isinstance(self.state[k], torch.Tensor) or \ isinstance(self.state[k], np.ndarray): self.state[k][id] = 0 elif isinstance(self.state, torch.Tensor) or \ isinstance(self.state, np.ndarray): self.state[id] = 0 def collect(self, n_step=0, n_episode=0, render=None): """Collect a specified number of step or episode. :param int n_step: how many steps you want to collect. :param n_episode: how many episodes you want to collect (in each environment). :type n_episode: int or list :param float render: the sleep time between rendering consecutive frames, defaults to ``None`` (no rendering). .. note:: One and only one collection number specification is permitted, either ``n_step`` or ``n_episode``. :return: A dict including the following keys * ``n/ep`` the collected number of episodes. * ``n/st`` the collected number of steps. * ``v/st`` the speed of steps per second. * ``v/ep`` the speed of episode per second. * ``rew`` the mean reward over collected episodes. * ``len`` the mean length over collected episodes. """ warning_count = 0 if not self._multi_env: n_episode = np.sum(n_episode) start_time = time.time() assert sum([(n_step != 0), (n_episode != 0)]) == 1, \ "One and only one collection number specification is permitted!" cur_step = 0 cur_episode = np.zeros(self.env_num) if self._multi_env else 0 reward_sum = 0 length_sum = 0 while True: if warning_count >= 100000: warnings.warn( 'There are already many steps in an episode. 
' 'You should add a time limitation to your environment!', Warning) if self._multi_env: batch_data = Batch( obs=self._obs, act=self._act, rew=self._rew, done=self._done, obs_next=None, info=self._info) else: batch_data = Batch( obs=self._make_batch(self._obs), act=self._make_batch(self._act), rew=self._make_batch(self._rew), done=self._make_batch(self._done), obs_next=None, info=self._make_batch(self._info)) with torch.no_grad(): result = self.policy(batch_data, self.state) self.state = result.state if hasattr(result, 'state') else None if isinstance(result.act, torch.Tensor): self._act = result.act.detach().cpu().numpy() elif not isinstance(self._act, np.ndarray): self._act = np.array(result.act) else: self._act = result.act obs_next, self._rew, self._done, self._info = self.env.step( self._act if self._multi_env else self._act[0]) if render is not None: self.env.render() if render > 0: time.sleep(render) self.length += 1 self.reward += self._rew if self._multi_env: for i in range(self.env_num): data = { 'obs': self._obs[i], 'act': self._act[i], 'rew': self._rew[i], 'done': self._done[i], 'obs_next': obs_next[i], 'info': self._info[i]} if self._cached_buf: warning_count += 1 self._cached_buf[i].add(**data) elif self._multi_buf: warning_count += 1 self.buffer[i].add(**data) cur_step += 1 else: warning_count += 1 self.buffer.add(**data) cur_step += 1 if self._done[i]: if n_step != 0 or np.isscalar(n_episode) or \ cur_episode[i] < n_episode[i]: cur_episode[i] += 1 reward_sum += self.reward[i] length_sum += self.length[i] if self._cached_buf: cur_step += len(self._cached_buf[i]) self.buffer.update(self._cached_buf[i]) self.reward[i], self.length[i] = 0, 0 if self._cached_buf: self._cached_buf[i].reset() self._reset_state(i) if sum(self._done): obs_next = self.env.reset(np.where(self._done)[0]) if n_episode != 0: if isinstance(n_episode, list) and \ (cur_episode >= np.array(n_episode)).all() or \ np.isscalar(n_episode) and \ cur_episode.sum() >= n_episode: break else: self.buffer.add( self._obs, self._act[0], self._rew, self._done, obs_next, self._info) cur_step += 1 if self._done: cur_episode += 1 reward_sum += self.reward length_sum += self.length self.reward, self.length = 0, 0 self.state = None obs_next = self.env.reset() if n_episode != 0 and cur_episode >= n_episode: break if n_step != 0 and cur_step >= n_step: break self._obs = obs_next self._obs = obs_next if self._multi_env: cur_episode = sum(cur_episode) duration = max(time.time() - start_time, 1e-9) self.step_speed.add(cur_step / duration) self.episode_speed.add(cur_episode / duration) self.collect_step += cur_step self.collect_episode += cur_episode self.collect_time += duration if isinstance(n_episode, list): n_episode = np.sum(n_episode) else: n_episode = max(cur_episode, 1) return { 'n/ep': cur_episode, 'n/st': cur_step, 'v/st': self.step_speed.get(), 'v/ep': self.episode_speed.get(), 'rew': reward_sum / n_episode, 'len': length_sum / n_episode, } def sample(self, batch_size): """Sample a data batch from the internal replay buffer. It will call :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the final batch data. :param int batch_size: ``0`` means it will extract all the data from the buffer, otherwise it will extract the data with the given batch_size. 
""" if self._multi_buf: if batch_size > 0: lens = [len(b) for b in self.buffer] total = sum(lens) batch_index = np.random.choice( total, batch_size, p=np.array(lens) / total) else: batch_index = np.array([]) batch_data = Batch() for i, b in enumerate(self.buffer): cur_batch = (batch_index == i).sum() if batch_size and cur_batch or batch_size <= 0: batch, indice = b.sample(cur_batch) batch = self.process_fn(batch, b, indice) batch_data.append(batch) else: batch_data, indice = self.buffer.sample(batch_size) batch_data = self.process_fn(batch_data, self.buffer, indice) return batch_data
def test_Fedppo(args=get_args()):
    torch.set_num_threads(1)  # for poor CPU
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    # train_envs = DummyVectorEnv(
    #     [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    # test_envs = DummyVectorEnv(
    #     [lambda: gym.make(args.task) for _ in range(args.test_num)])
    if args.data_quantity != 0:
        env.set_data_quantity(args.data_quantity)
    if args.data_quality != 0:
        env.set_data_quality(args.data_quality)
    if args.psi != 0:
        env.set_psi(args.psi)
    if args.nu != 0:
        env.set_nu(args.nu)
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # train_envs.seed(args.seed)
    # test_envs.seed(args.seed)
    # model
    # server policy
    server_policy = build_policy(0, args)
    # client policies
    ND_policy = build_policy(1, args)
    RD_policy = build_policy(2, args)
    FD_policy = build_policy(3, args)
    # use plain replay buffers directly instead of a Collector
    server_buffer = ReplayBuffer(args.buffer_size)
    ND_buffer = ReplayBuffer(args.buffer_size)
    RD_buffer = ReplayBuffer(args.buffer_size)
    FD_buffer = ReplayBuffer(args.buffer_size)
    # log
    log_path = os.path.join(args.logdir, args.task, 'ppo')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    # hand-written trainer and tester from here on;
    # to inspect the server's convergence, the client networks are not
    # trained at first
    start_time = time.time()
    _server_obs, _ND_obs, _RD_obs, _FD_obs = env.reset()
    _server_act = _server_rew = _done = _info = None
    server_buffer.reset()
    _ND_act = _ND_rew = _RD_act = _RD_rew = _FD_act = _FD_rew = [None]
    ND_buffer.reset()
    RD_buffer.reset()
    FD_buffer.reset()
    all_server_costs = []
    all_ND_utility = []
    all_RD_utility = []
    all_FD_utility = []
    all_leak_probability = []
    for epoch in range(1, 1 + args.epoch):
        # each epoch collects N*T transitions, then trains M times with
        # batch size B
        server_costs = []
        ND_utility = []
        FD_utility = []
        RD_utility = []
        leak_probability = []
        payment = []
        expected_time = []
        training_time = []
        with tqdm.tqdm(total=args.step_per_epoch, desc=f'Epoch #{epoch}',
                       **tqdm_config) as t:
            while t.n < t.total:
                # collect data without gradients
                # server
                _server_obs, _ND_obs, _RD_obs, _FD_obs = env.reset()
                server_batch = Batch(
                    obs=_server_obs, act=_server_act, rew=_server_rew,
                    done=_done, obs_next=None, info=_info, policy=None)
                with torch.no_grad():
                    server_result = server_policy(server_batch, None)
                _server_policy = [{}]
                _server_act = to_numpy(server_result.act)
                # ND
                ND_batch = Batch(
                    obs=_ND_obs, act=_ND_act, rew=_ND_rew, done=_done,
                    obs_next=None, info=_info, policy=None)
                with torch.no_grad():
                    ND_result = ND_policy(ND_batch, None)
                _ND_policy = [{}]
                _ND_act = to_numpy(ND_result.act)
                # RD
                RD_batch = Batch(
                    obs=_RD_obs, act=_RD_act, rew=_RD_rew, done=_done,
                    obs_next=None, info=_info, policy=None)
                with torch.no_grad():
                    RD_result = RD_policy(RD_batch, None)
                _RD_policy = [{}]
                _RD_act = to_numpy(RD_result.act)
                # FD
                FD_batch = Batch(
                    obs=_FD_obs, act=_FD_act, rew=_FD_rew, done=_done,
                    obs_next=None, info=_info, policy=None)
                with torch.no_grad():
                    FD_result = FD_policy(FD_batch, None)
                _FD_policy = [{}]
                _FD_act = to_numpy(FD_result.act)
                # print(_ND_act.shape)
                (server_obs_next, ND_obs_next, RD_obs_next, FD_obs_next,
                 _server_rew, _client_rew, _done, _info) = env.step(
                    _server_act[0], _ND_act[0], _RD_act[0], _FD_act[0])
                server_costs.append(_server_rew)
                ND_utility.append(_client_rew[0])
                RD_utility.append(_client_rew[1])
                FD_utility.append(_client_rew[2])
                leak_probability.append(_info[0]["leak"])
                payment.append(env.payment)
                expected_time.append(env.expected_time)
                training_time.append(env.global_time * env.time_lambda)
                # add to the replay buffers
                server_buffer.add(Batch(
                    obs=_server_obs[0], act=_server_act[0],
                    rew=_server_rew[0], done=_done[0],
                    obs_next=server_obs_next[0], info=_info[0],
                    policy=_server_policy[0]))
                ND_buffer.add(Batch(
                    obs=_ND_obs[0], act=_ND_act[0], rew=_client_rew[0],
                    done=_done[0], obs_next=ND_obs_next[0], info=_info[0],
                    policy=_ND_policy[0]))
                RD_buffer.add(Batch(
                    obs=_RD_obs[0], act=_RD_act[0], rew=_client_rew[1],
                    done=_done[0], obs_next=RD_obs_next[0], info=_info[0],
                    policy=_RD_policy[0]))
                FD_buffer.add(Batch(
                    obs=_FD_obs[0], act=_FD_act[0], rew=_client_rew[2],
                    done=_done[0], obs_next=FD_obs_next[0], info=_info[0],
                    policy=_FD_policy[0]))
                t.update(1)
                _server_obs = server_obs_next
                _ND_obs = ND_obs_next
                _RD_obs = RD_obs_next
                _FD_obs = FD_obs_next
        all_server_costs.append(np.array(server_costs).mean())
        all_ND_utility.append(np.array(ND_utility).mean())
        all_RD_utility.append(np.array(RD_utility).mean())
        all_FD_utility.append(np.array(FD_utility).mean())
        all_leak_probability.append(np.array(leak_probability).mean())
        print("current bandwidth:", env.bandwidth)
        print("leak signal:", env.leak_NU, env.leak_FU)
        print("current server cost:", np.array(server_costs).mean())
        print("current device utility:", all_ND_utility[-1],
              all_RD_utility[-1], all_FD_utility[-1])
        print("leak probability:", all_leak_probability[-1])
        print("server_act:", _server_act[0])
        print("device_acts:", _ND_act[0], _RD_act[0], _FD_act[0])
        print("payment cost:", np.array(payment).mean())
        print("Expected time cost:", np.array(expected_time).mean())
        print("Training time cost:", np.array(training_time).mean())
        # print("server_act:", _server_act)
        # print("client_act:", _client_act)
        print("info:", env.communication_time, env.computation_time,
              env.K_theta)
        server_batch_data, server_indice = server_buffer.sample(0)
        server_batch_data = server_policy.process_fn(
            server_batch_data, server_buffer, server_indice)
        server_policy.learn(server_batch_data, args.batch_size,
                            args.repeat_per_collect)
        server_buffer.reset()
        ND_batch_data, ND_indice = ND_buffer.sample(0)
        ND_batch_data = ND_policy.process_fn(
            ND_batch_data, ND_buffer, ND_indice)
        ND_policy.learn(ND_batch_data, args.batch_size,
                        args.repeat_per_collect)
        ND_buffer.reset()
        RD_batch_data, RD_indice = RD_buffer.sample(0)
        RD_batch_data = RD_policy.process_fn(
            RD_batch_data, RD_buffer, RD_indice)
        RD_policy.learn(RD_batch_data, args.batch_size,
                        args.repeat_per_collect)
        RD_buffer.reset()
        FD_batch_data, FD_indice = FD_buffer.sample(0)
        FD_batch_data = FD_policy.process_fn(
            FD_batch_data, FD_buffer, FD_indice)
        FD_policy.learn(FD_batch_data, args.batch_size,
                        args.repeat_per_collect)
        FD_buffer.reset()
    print("all_server_cost:", all_server_costs)
    print("all_ND_utility:", all_ND_utility)
    print("all_RD_utility:", all_RD_utility)
    print("all_FD_utility:", all_FD_utility)
    print("all_leak_probability:", all_leak_probability)
    plt.plot(all_server_costs)
    plt.show()
def test_replaybuffer(size=10, bufsize=20):
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize)
    buf.update(buf)
    assert str(buf) == buf.__class__.__name__ + '()'
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        buf.add(Batch(obs=obs, act=[a], rew=rew, done=done,
                      obs_next=obs_next, info=info))
        obs = obs_next
        assert len(buf) == min(bufsize, i + 1)
    assert buf.act.dtype == int
    assert buf.act.shape == (bufsize, 1)
    data, indices = buf.sample(bufsize * 2)
    assert (indices < len(buf)).all()
    assert (data.obs < size).all()
    assert (0 <= data.done).all() and (data.done <= 1).all()
    b = ReplayBuffer(size=10)
    # a negative batch size should return an empty index array
    assert b.sample_indices(-1).tolist() == []
    ptr, ep_rew, ep_len, ep_idx = b.add(
        Batch(obs=1, act=1, rew=1, done=1, obs_next='str',
              info={'a': 3, 'b': {'c': 5.0}}))
    assert b.obs[0] == 1
    assert b.done[0]
    assert b.obs_next[0] == 'str'
    assert np.all(b.obs[1:] == 0)
    assert np.all(b.obs_next[1:] == np.array(None))
    assert b.info.a[0] == 3 and b.info.a.dtype == int
    assert np.all(b.info.a[1:] == 0)
    assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == float
    assert np.all(b.info.b.c[1:] == 0.0)
    assert ptr.shape == (1,) and ptr[0] == 0
    assert ep_rew.shape == (1,) and ep_rew[0] == 1
    assert ep_len.shape == (1,) and ep_len[0] == 1
    assert ep_idx.shape == (1,) and ep_idx[0] == 0
    # test extra keys pop up; the buffer should handle them dynamically
    batch = Batch(obs=2, act=2, rew=2, done=0, obs_next="str2",
                  info={"a": 4, "d": {"e": -np.inf}})
    b.add(batch)
    info_keys = ["a", "b", "d"]
    assert set(b.info.keys()) == set(info_keys)
    assert b.info.a[1] == 4 and b.info.b.c[1] == 0
    assert b.info.d.e[1] == -np.inf
    # test batch-style adding method, where len(batch) == 1
    batch.done = [1]
    batch.info.e = np.zeros([1, 4])
    batch = Batch.stack([batch])
    ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0])
    assert ptr.shape == (1,) and ptr[0] == 2
    assert ep_rew.shape == (1,) and ep_rew[0] == 4
    assert ep_len.shape == (1,) and ep_len[0] == 2
    assert ep_idx.shape == (1,) and ep_idx[0] == 1
    assert set(b.info.keys()) == set(info_keys + ["e"])
    assert b.info.e.shape == (b.maxsize, 1, 4)
    with pytest.raises(IndexError):
        b[22]
    # test prev / next
    assert np.all(b.prev(np.array([0, 1, 2])) == [0, 1, 1])
    assert np.all(b.next(np.array([0, 1, 2])) == [0, 2, 2])
    batch.done = [0]
    b.add(batch, buffer_ids=[0])
    assert np.all(b.prev(np.array([0, 1, 2, 3])) == [0, 1, 1, 3])
    assert np.all(b.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3])
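# The prev/next assertions above describe episode-aware neighbors: prev(i)
# steps back one slot unless i starts an episode (or is the oldest stored
# transition), and next(i) steps forward unless i is terminal (or is the
# newest stored transition). A minimal sketch of that contract, using
# hypothetical standalone functions over the valid portion of the buffer
# (`first` is the oldest index, `last` the newest):
def prev_index(done, i, first):
    # stop at episode starts and at the oldest transition
    if i == first:
        return i
    j = (i - 1) % len(done)
    return i if done[j] else j

def next_index(done, i, last):
    # stop at terminal transitions and at the newest transition
    return i if done[i] or i == last else (i + 1) % len(done)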