Example #1
 def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                indice: np.ndarray) -> Batch:
     if self._rew_norm:
         mean, std = batch.rew.mean(), batch.rew.std()
         if not np.isclose(std, 0, 1e-2):
             batch.rew = (batch.rew - mean) / std
     v, v_, old_log_prob = [], [], []
     with torch.no_grad():
         for b in batch.split(self._batch, shuffle=False):
             v_.append(self.critic(b.obs_next))
             v.append(self.critic(b.obs))
             old_log_prob.append(self(b).dist.log_prob(
                 to_torch_as(b.act, v[0])))
     v_ = to_numpy(torch.cat(v_, dim=0))
     batch = self.compute_episodic_return(
         batch, v_, gamma=self._gamma, gae_lambda=self._lambda,
         rew_norm=self._rew_norm)
     batch.v = torch.cat(v, dim=0).flatten()  # old value
     batch.act = to_torch_as(batch.act, v[0])
     batch.logp_old = torch.cat(old_log_prob, dim=0)
     batch.returns = to_torch_as(batch.returns, v[0])
     batch.adv = batch.returns - batch.v
     if self._rew_norm:
         mean, std = batch.adv.mean(), batch.adv.std()
         if not np.isclose(std.item(), 0, 1e-2):
             batch.adv = (batch.adv - mean) / std
     return batch
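For context, `compute_episodic_return` with a `gae_lambda` argument is assumed here to implement generalized advantage estimation (GAE, Schulman et al., 2015); the advantage recovered afterwards as `batch.adv = batch.returns - batch.v` then corresponds to

.. math::
    \hat{A}_t = \sum_{l \ge 0} (\gamma \lambda)^l \delta_{t+l}, \qquad
    \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t).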
Example #2
 def learn(  # type: ignore
         self, batch: Batch, batch_size: int, repeat: int,
         **kwargs: Any) -> Dict[str, List[float]]:
     losses, actor_losses, vf_losses, ent_losses = [], [], [], []
     for _ in range(repeat):
         for b in batch.split(batch_size, merge_last=True):
             self.optim.zero_grad()
             dist = self(b).dist
             v = self.critic(b.obs).flatten()
             a = to_torch_as(b.act, v)
             r = to_torch_as(b.returns, v)
             log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
             a_loss = -(log_prob * (r - v).detach()).mean()
             vf_loss = F.mse_loss(r, v)  # type: ignore
             ent_loss = dist.entropy().mean()
             loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
             loss.backward()
             if self._grad_norm is not None:
                 nn.utils.clip_grad_norm_(
                     list(self.actor.parameters()) +
                     list(self.critic.parameters()),
                     max_norm=self._grad_norm,
                 )
             self.optim.step()
             actor_losses.append(a_loss.item())
             vf_losses.append(vf_loss.item())
             ent_losses.append(ent_loss.item())
             losses.append(loss.item())
     return {
         "loss": losses,
         "loss/actor": actor_losses,
         "loss/vf": vf_losses,
         "loss/ent": ent_losses,
     }
Example #3
 def learn(self, batch: Batch, batch_size: int, repeat: int,
           **kwargs) -> Dict[str, List[float]]:
     self._batch = batch_size
     r = batch.returns
     if self._rew_norm and not np.isclose(r.std(), 0):
         batch.returns = (r - r.mean()) / r.std()
     losses, actor_losses, vf_losses, ent_losses = [], [], [], []
     for _ in range(repeat):
         for b in batch.split(batch_size):
             self.optim.zero_grad()
             dist = self(b).dist
             v = self.critic(b.obs).squeeze(-1)
             a = to_torch_as(b.act, v)
             r = to_torch_as(b.returns, v)
             a_loss = -(dist.log_prob(a).reshape(v.shape) *
                        (r - v).detach()).mean()
             vf_loss = F.mse_loss(r, v)
             ent_loss = dist.entropy().mean()
             loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
             loss.backward()
             if self._grad_norm is not None:
                 nn.utils.clip_grad_norm_(list(self.actor.parameters()) +
                                          list(self.critic.parameters()),
                                          max_norm=self._grad_norm)
             self.optim.step()
             actor_losses.append(a_loss.item())
             vf_losses.append(vf_loss.item())
             ent_losses.append(ent_loss.item())
             losses.append(loss.item())
     return {
         'loss': losses,
         'loss/actor': actor_losses,
         'loss/vf': vf_losses,
         'loss/ent': ent_losses,
     }
Example #4
 def _compute_returns(self, batch: Batch, buffer: ReplayBuffer,
                      indice: np.ndarray) -> Batch:
     v_s, v_s_ = [], []
     with torch.no_grad():
         for b in batch.split(self._batch, shuffle=False, merge_last=True):
             v_s.append(self.critic(b.obs))
             v_s_.append(self.critic(b.obs_next))
     batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
     v_s = batch.v_s.cpu().numpy()
     v_s_ = torch.cat(v_s_, dim=0).flatten().cpu().numpy()
     # When normalizing values, we do not subtract self.ret_rms.mean, to stay
     # numerically consistent with OpenAI Baselines' value normalization pipeline.
     # Empirical study also shows that subtracting the mean slightly hurts
     # performance, for unknown reasons (on MuJoCo envs, though this is not a
     # confident conclusion).
     if self._rew_norm:  # unnormalize v_s & v_s_
         v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
         v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
     unnormalized_returns, advantages = self.compute_episodic_return(
         batch,
         buffer,
         indice,
         v_s_,
         v_s,
         gamma=self._gamma,
         gae_lambda=self._lambda)
     if self._rew_norm:
         batch.returns = unnormalized_returns / \
             np.sqrt(self.ret_rms.var + self._eps)
         self.ret_rms.update(unnormalized_returns)
     else:
         batch.returns = unnormalized_returns
     batch.returns = to_torch_as(batch.returns, batch.v_s)
     batch.adv = to_torch_as(advantages, batch.v_s)
     return batch
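The `self.ret_rms` statistics object used above is not shown in this example. A minimal running mean/variance tracker in the style of OpenAI Baselines' `RunningMeanStd` could look like the following sketch; the actual class used by the policy may differ, and the `eps` default is an assumption.

import numpy as np

class RunningMeanStd:
    """Track a running mean and variance of returns (sketch only)."""

    def __init__(self, eps: float = 1e-4) -> None:
        self.mean, self.var, self.count = 0.0, 1.0, eps

    def update(self, x: np.ndarray) -> None:
        # Parallel (Chan et al.) update of mean/variance with a new batch.
        batch_mean, batch_var, batch_count = x.mean(), x.var(), len(x)
        delta = batch_mean - self.mean
        total = self.count + batch_count
        self.mean = self.mean + delta * batch_count / total
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        self.var = (m_a + m_b + delta ** 2 * self.count * batch_count / total) / total
        self.count = total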
Example #5
 def process_fn(
     self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
 ) -> Batch:
     v_s, v_s_, old_log_prob = [], [], []
     with torch.no_grad():
         for b in batch.split(self._batch, shuffle=False, merge_last=True):
             v_s.append(self.critic(b.obs))
             v_s_.append(self.critic(b.obs_next))
             old_log_prob.append(self(b).dist.log_prob(to_torch_as(b.act, v_s[0])))
     batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
     v_s = to_numpy(batch.v_s)
     v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
     if self._rew_norm:  # unnormalize v_s & v_s_
         v_s = v_s * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
         v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
     unnormalized_returns, advantages = self.compute_episodic_return(
         batch, buffer, indice, v_s_, v_s,
         gamma=self._gamma, gae_lambda=self._lambda)
     if self._rew_norm:
         batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
             np.sqrt(self.ret_rms.var + self._eps)
         self.ret_rms.update(unnormalized_returns)
         mean, std = np.mean(advantages), np.std(advantages)
         advantages = (advantages - mean) / std  # per-batch norm
     else:
         batch.returns = unnormalized_returns
     batch.act = to_torch_as(batch.act, batch.v_s)
     batch.logp_old = torch.cat(old_log_prob, dim=0)
     batch.returns = to_torch_as(batch.returns, batch.v_s)
     batch.adv = to_torch_as(advantages, batch.v_s)
     return batch
Example #6
    def compute_nstep_return(
        batch: Batch,
        buffer: ReplayBuffer,
        indice: np.ndarray,
        target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor],
        gamma: float = 0.99,
        n_step: int = 1,
        rew_norm: bool = False,
    ) -> Batch:
        r"""Compute n-step return for Q-learning targets.

        .. math::
            G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
            \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n})

        where :math:`\gamma` is the discount factor,
        :math:`\gamma \in [0, 1]`, :math:`d_t` is the done flag of step
        :math:`t`.

        :param batch: a data batch, which is equal to buffer[indice].
        :type batch: :class:`~tianshou.data.Batch`
        :param buffer: a data buffer which contains several complete episodes
            of data in chronological order.
        :type buffer: :class:`~tianshou.data.ReplayBuffer`
        :param indice: sampled timestep.
        :type indice: numpy.ndarray
        :param function target_q_fn: a function that receives the
            :math:`t+n-1` step's data and computes the target Q value.
        :param float gamma: the discount factor, should be in [0, 1], defaults
            to 0.99.
        :param int n_step: the number of estimation steps, should be an int
            greater than 0, defaults to 1.
        :param bool rew_norm: normalize the reward to Normal(0, 1), defaults
            to False.

        :return: a Batch. The result will be stored in batch.returns as a
            torch.Tensor with shape (bsz, ).
        """
        rew = buffer.rew
        if rew_norm:
            bfr = rew[:min(len(buffer), 1000)]  # avoid large buffer
            mean, std = bfr.mean(), bfr.std()
            if np.isclose(std, 0, 1e-2):
                mean, std = 0.0, 1.0
        else:
            mean, std = 0.0, 1.0
        buf_len = len(buffer)
        terminal = (indice + n_step - 1) % buf_len
        target_q_torch = target_q_fn(buffer, terminal).flatten()  # (bsz, )
        target_q = to_numpy(target_q_torch)

        target_q = _nstep_return(rew, buffer.done, target_q, indice, gamma,
                                 n_step, len(buffer), mean, std)

        batch.returns = to_torch_as(target_q, target_q_torch)
        # prio buffer update
        if isinstance(buffer, PrioritizedReplayBuffer):
            batch.weight = to_torch_as(batch.weight, target_q_torch)
        return batch
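The `_nstep_return` helper called above is not shown on this page. Below is a minimal sketch of such a helper, with the signature taken from the call site in Example #6 and the body mirroring the inline logic of Example #7 further down; the real implementation may differ.

import numpy as np

def _nstep_return(rew, done, target_q, indice, gamma, n_step,
                  buf_len, mean, std):
    """Backward accumulation of the n-step return (hypothetical sketch)."""
    returns = np.zeros(indice.shape)
    gammas = np.full(indice.shape, n_step)
    for n in range(n_step - 1, -1, -1):
        now = (indice + n) % buf_len
        gammas[done[now] > 0] = n        # episode ended: shorten the horizon
        returns[done[now] > 0] = 0.0     # drop rewards beyond the episode end
        returns = (rew[now] - mean) / std + gamma * returns
    target_q = target_q.copy()
    target_q[gammas != n_step] = 0.0     # do not bootstrap across episode ends
    return target_q * gamma ** gammas + returns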
Example #7
    def compute_nstep_return(batch: Batch,
                             buffer: ReplayBuffer,
                             indice: np.ndarray,
                             target_q_fn: Callable[[ReplayBuffer, np.ndarray],
                                                   torch.Tensor],
                             gamma: float = 0.99,
                             n_step: int = 1,
                             rew_norm: bool = False) -> Batch:
        r"""Compute n-step return for Q-learning targets:

        .. math::
            G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
            \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n})

        where :math:`\gamma` is the discount factor,
        :math:`\gamma \in [0, 1]`, :math:`d_t` is the done flag of step
        :math:`t`.

        :param batch: a data batch, which is equal to buffer[indice].
        :type batch: :class:`~tianshou.data.Batch`
        :param buffer: a data buffer which contains several complete episodes
            of data in chronological order.
        :type buffer: :class:`~tianshou.data.ReplayBuffer`
        :param indice: sampled timestep.
        :type indice: numpy.ndarray
        :param function target_q_fn: a function that receives the
            :math:`t+n-1` step's data and computes the target Q value.
        :param float gamma: the discount factor, should be in [0, 1], defaults
            to 0.99.
        :param int n_step: the number of estimation steps, should be an int
            greater than 0, defaults to 1.
        :param bool rew_norm: normalize the reward to Normal(0, 1), defaults
            to ``False``.

        :return: a Batch. The result will be stored in batch.returns as a
            torch.Tensor with shape (bsz, ).
        """
        if rew_norm:
            bfr = buffer.rew[:min(len(buffer), 1000)]  # avoid large buffer
            mean, std = bfr.mean(), bfr.std()
            if np.isclose(std, 0):
                mean, std = 0, 1
        else:
            mean, std = 0, 1
        returns = np.zeros_like(indice)
        gammas = np.zeros_like(indice) + n_step
        done, rew, buf_len = buffer.done, buffer.rew, len(buffer)
        for n in range(n_step - 1, -1, -1):
            now = (indice + n) % buf_len
            gammas[done[now] > 0] = n
            returns[done[now] > 0] = 0
            returns = (rew[now] - mean) / std + gamma * returns
        terminal = (indice + n_step - 1) % buf_len
        target_q = target_q_fn(buffer, terminal).squeeze()
        target_q[gammas != n_step] = 0
        returns = to_torch_as(returns, target_q)
        gammas = to_torch_as(gamma**gammas, target_q)
        batch.returns = target_q * gammas + returns
        return batch
Example #8
    def compute_nstep_return(
        batch: Batch,
        buffer: ReplayBuffer,
        indice: np.ndarray,
        target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor],
        gamma: float = 0.99,
        n_step: int = 1,
        rew_norm: bool = False,
        use_mixed: bool = False,
    ) -> Batch:
        r"""Compute n-step return for Q-learning targets.

        .. math::
            G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
            \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n})

        where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`,
        :math:`d_t` is the done flag of step :math:`t`.

        :param Batch batch: a data batch, which is equal to buffer[indice].
        :param ReplayBuffer buffer: the data buffer.
        :param function target_q_fn: a function which computes the target Q
            value of "obs_next" given the data buffer and the desired indices.
        :param float gamma: the discount factor, should be in [0, 1]. Default to 0.99.
        :param int n_step: the number of estimation steps, should be an int
            greater than 0. Default to 1.
        :param bool rew_norm: normalize the reward to Normal(0, 1). Default to False.
        :param bool use_mixed: whether to evaluate the target Q inside an
            ``autocast`` context (mixed precision). Default to False.

        :return: a Batch. The result will be stored in batch.returns as a
            torch.Tensor with the same shape as target_q_fn's return tensor.
        """
        assert not rew_norm, \
            "Reward normalization in computing n-step returns is unsupported now."
        rew = buffer.rew
        bsz = len(indice)
        indices = [indice]
        for _ in range(n_step - 1):
            indices.append(buffer.next(indices[-1]))
        indices = np.stack(indices)
        # terminal indicates the buffer indices n_step after 'indice',
        # truncated at the end of each episode
        terminal = indices[-1]
        with autocast(enabled=use_mixed):
            with torch.no_grad():
                target_q_torch = target_q_fn(buffer, terminal)  # (bsz, ?)
        target_q = to_numpy(target_q_torch.float().reshape(bsz, -1))
        target_q = target_q * BasePolicy.value_mask(buffer, terminal).reshape(
            -1, 1)
        end_flag = buffer.done.copy()
        end_flag[buffer.unfinished_index()] = True
        target_q = _nstep_return(rew, end_flag, target_q, indices, gamma,
                                 n_step)

        batch.returns = to_torch_as(target_q, target_q_torch)
        if hasattr(batch, "weight"):  # prio buffer update
            batch.weight = to_torch_as(batch.weight, target_q_torch)
        return batch
Example #9
 def learn(  # type: ignore
         self, batch: Batch, batch_size: int, repeat: int,
         **kwargs: Any) -> Dict[str, List[float]]:
     losses = []
     for _ in range(repeat):
         for b in batch.split(batch_size, merge_last=True):
             self.optim.zero_grad()
             dist = self(b).dist
             a = to_torch_as(b.act, dist.logits)
             r = to_torch_as(b.returns, dist.logits)
             log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
             loss = -(log_prob * r).mean()
             loss.backward()
             self.optim.step()
             losses.append(loss.item())
     return {"loss": losses}
Example #10
 def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     if self._target and self._iter % self._freq == 0:
         self.sync_weight()
     self.optim.zero_grad()
     weight = batch.pop("weight", 1.0)
     q = self(batch).logits
     q = q[np.arange(len(q)), batch.act]
     r = to_torch_as(batch.returns.flatten(), q)
     weight = to_torch_as(weight, q)
     td = r - q
     loss = (td.pow(2) * weight).mean()
     batch.weight = td  # prio-buffer
     loss.backward()
     self.optim.step()
     self._iter += 1
     return {"loss": loss.item()}
Example #11
    def forward(self,
                batch: Batch,
                state: Optional[Union[dict, Batch, np.ndarray]] = None,
                model: str = 'actor',
                input: str = 'obs',
                explorating: bool = True,
                **kwargs) -> Batch:
        """Compute action over the given batch data.

        :return: A :class:`~tianshou.data.Batch` which has 2 keys:

            * ``act`` the action.
            * ``state`` the hidden state.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        model = getattr(self, model)
        obs = getattr(batch, input)
        logits, h = model(obs, state=state, info=batch.info)
        actions = torch.tanh(logits)
        if self.training and explorating:
            actions = actions + to_torch_as(self._noise(actions.shape),
                                            actions)

        actions = actions.clamp(self._range[0], self._range[1])
        return Batch(act=actions, state=h)
Example #12
 def forward(  # type: ignore
     self,
     batch: Batch,
     state: Optional[Union[dict, Batch, np.ndarray]] = None,
     input: str = "obs",
     **kwargs: Any,
 ) -> Batch:
     obs = batch[input]
     logits, h = self.actor(obs, state=state, info=batch.info)
     assert isinstance(logits, tuple)
     dist = Independent(Normal(*logits), 1)
     if self._deterministic_eval and not self.training:
         act = logits[0]
     else:
         act = dist.rsample()
     log_prob = dist.log_prob(act).unsqueeze(-1)
     # Apply the correction for tanh squashing when computing the log-prob of the
     # Gaussian sample; see Eq. 21 in Appendix C of the original SAC paper
     # (arXiv 1801.01290) for the derivation.
     if self.action_scaling and self.action_space is not None:
         action_scale = to_torch_as(
             (self.action_space.high - self.action_space.low) / 2.0, act)
     else:
         action_scale = 1.0  # type: ignore
     squashed_action = torch.tanh(act)
     log_prob = log_prob - torch.log(action_scale *
                                     (1 - squashed_action.pow(2)) +
                                     self.__eps).sum(-1, keepdim=True)
     return Batch(logits=logits,
                  act=squashed_action,
                  state=h,
                  dist=dist,
                  log_prob=log_prob)
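For reference, the squashing correction implemented above (cf. Eq. 21 of the SAC paper) amounts to a change of variables:

.. math::
    \log \pi(a \mid s) = \log \mu(u \mid s)
    - \sum_i \log\bigl(s \,(1 - \tanh^2(u_i)) + \epsilon\bigr),

where :math:`u` is the pre-squash Gaussian sample, :math:`a = \tanh(u)` the squashed action, :math:`s` the action scale, and :math:`\epsilon` a small constant for numerical stability.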
Example #13
    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        model: str = "actor",
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        """Compute action over the given batch data.

        :return: A :class:`~tianshou.data.Batch` which has 2 keys:

            * ``act`` the action.
            * ``state`` the hidden state.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        model = getattr(self, model)
        obs = batch[input]
        actions, h = model(obs, state=state, info=batch.info)
        actions += self._action_bias
        if self._noise and not self.updating:
            actions += to_torch_as(self._noise(actions.shape), actions)
        actions = actions.clamp(self._range[0], self._range[1])
        return Batch(act=actions, state=h)
Example #14
 def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
     # critic 1
     current_q1 = self.critic1(batch.obs, batch.act)
     target_q = to_torch_as(batch.returns, current_q1)[:, None]
     critic1_loss = F.mse_loss(current_q1, target_q)
     self.critic1_optim.zero_grad()
     critic1_loss.backward()
     self.critic1_optim.step()
     # critic 2
     current_q2 = self.critic2(batch.obs, batch.act)
     critic2_loss = F.mse_loss(current_q2, target_q)
     self.critic2_optim.zero_grad()
     critic2_loss.backward()
     self.critic2_optim.step()
     # actor
     obs_result = self(batch)
     a = obs_result.act
     current_q1a = self.critic1(batch.obs, a)
     current_q2a = self.critic2(batch.obs, a)
     actor_loss = (self._alpha * obs_result.log_prob -
                   torch.min(current_q1a, current_q2a)).mean()
     self.actor_optim.zero_grad()
     actor_loss.backward()
     self.actor_optim.step()
     self.sync_weight()
     return {
         'loss/actor': actor_loss.item(),
         'loss/critic1': critic1_loss.item(),
         'loss/critic2': critic2_loss.item(),
     }
Example #15
    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        # critic 1&2
        td1, critic1_loss = self._mse_optimizer(
            batch, self.critic1, self.critic1_optim
        )
        td2, critic2_loss = self._mse_optimizer(
            batch, self.critic2, self.critic2_optim
        )
        batch.weight = (td1 + td2) / 2.0  # prio-buffer

        # actor
        if self._cnt % self._freq == 0:
            act = self(batch, eps=0.0).act
            q_value = self.critic1(batch.obs, act)
            lmbda = self._alpha / q_value.abs().mean().detach()
            actor_loss = -lmbda * q_value.mean() + F.mse_loss(
                act, to_torch_as(batch.act, act)
            )
            self.actor_optim.zero_grad()
            actor_loss.backward()
            self._last = actor_loss.item()
            self.actor_optim.step()
            self.sync_weight()
        self._cnt += 1
        return {
            "loss/actor": self._last,
            "loss/critic1": critic1_loss.item(),
            "loss/critic2": critic2_loss.item(),
        }
Example #16
 def forward(  # type: ignore
     self,
     batch: Batch,
     state: Optional[Union[dict, Batch, np.ndarray]] = None,
     input: str = "obs",
     **kwargs: Any,
 ) -> Batch:
     obs = batch[input]
     logits, h = self.actor(obs, state=state, info=batch.info)
     assert isinstance(logits, tuple)
     dist = Independent(Normal(*logits), 1)
     if self._deterministic_eval and not self.training:
         x = logits[0]
     else:
         x = dist.rsample()
     y = torch.tanh(x)
     act = y * self._action_scale + self._action_bias
     y = self._action_scale * (1 - y.pow(2)) + self.__eps
     log_prob = dist.log_prob(x).unsqueeze(-1)
     log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
     if self._noise is not None and self.training and not self.updating:
         act += to_torch_as(self._noise(act.shape), act)
     act = act.clamp(self._range[0], self._range[1])
     return Batch(
         logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
Example #17
 def compute_q_value(self, logits: torch.Tensor,
                     mask: Optional[np.ndarray]) -> torch.Tensor:
     """Compute the q value based on the network's raw output and action mask."""
     if mask is not None:
         # the masked q value should be smaller than logits.min()
         min_value = logits.min() - logits.max() - 1.0
         logits = logits + to_torch_as(1 - mask, logits) * min_value
     return logits
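A minimal, self-contained sketch of the masking trick above; the numbers are made up and `to_torch_as` is replaced by a plain dtype conversion to keep the snippet standalone.

import numpy as np
import torch

logits = torch.tensor([[1.0, 2.0, 3.0]])         # raw Q values for 3 actions
mask = np.array([[1, 0, 1]])                      # action 1 is illegal
min_value = logits.min() - logits.max() - 1.0     # strictly below every legal Q
masked = logits + torch.as_tensor(1 - mask, dtype=logits.dtype) * min_value
assert masked.argmax(dim=1).item() == 2           # the illegal action is never picked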
Example #18
 def learn(self, batch: Batch, batch_size: int, repeat: int,
           **kwargs) -> Dict[str, List[float]]:
     losses = []
     r = batch.returns
     if self._rew_norm and not np.isclose(r.std(), 0):
         batch.returns = (r - r.mean()) / r.std()
     for _ in range(repeat):
         for b in batch.split(batch_size):
             self.optim.zero_grad()
             dist = self(b).dist
             a = to_torch_as(b.act, dist.logits)
             r = to_torch_as(b.returns, dist.logits)
             loss = -(dist.log_prob(a) * r).sum()
             loss.backward()
             self.optim.step()
             losses.append(loss.item())
     return {'loss': losses}
Example #19
 def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
     if self._target and self._cnt % self._freq == 0:
         self.sync_weight()
     self.optim.zero_grad()
     q = self(batch).logits
     q = q[np.arange(len(q)), batch.act]
     r = to_torch_as(batch.returns, q)
     if hasattr(batch, 'update_weight'):
         td = r - q
         batch.update_weight(batch.indice, to_numpy(td))
         impt_weight = to_torch_as(batch.impt_weight, q)
         loss = (td.pow(2) * impt_weight).mean()
     else:
         loss = F.mse_loss(q, r)
     loss.backward()
     self.optim.step()
     self._cnt += 1
     return {'loss': loss.item()}
Example #20
 def _target_q(self, buffer: ReplayBuffer,
               indice: np.ndarray) -> torch.Tensor:
     batch = buffer[indice]  # batch.obs: s_{t+n}
     with torch.no_grad():
         obs_next_result = self(batch, input='obs_next', explorating=False)
         a_ = obs_next_result.act
         batch.act = to_torch_as(batch.act, a_)
         target_q = torch.min(
             self.critic1_old(batch.obs_next, a_),
             self.critic2_old(batch.obs_next, a_),
         ) - self._alpha * obs_next_result.log_prob
     return target_q
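The target above is the soft Bellman backup used by SAC-style policies:

.. math::
    Q_{\mathrm{target}}(s') = \min_{i=1,2} Q_i^{\mathrm{old}}(s', a')
    - \alpha \log \pi(a' \mid s'), \qquad a' \sim \pi(\cdot \mid s'),

with :math:`s'` the observation :math:`n` steps ahead (``obs_next``) and :math:`\alpha` the entropy temperature.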
Example #21
    def learn(  # type: ignore
            self, batch: Batch, batch_size: int, repeat: int,
            **kwargs: Any) -> Dict[str, List[float]]:
        losses = []
        for _ in range(repeat):
            for b in batch.split(batch_size, merge_last=True):
                self.optim.zero_grad()
                result = self(b)
                dist = result.dist
                a = to_torch_as(b.act, result.act)
                r = to_torch_as(b.returns, result.act)
                log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
                loss = -(log_prob * r).mean()
                loss.backward()
                self.optim.step()
                losses.append(loss.item())
        # update learning rate if lr_scheduler is given
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return {"loss": losses}
Example #22
    def learn(  # type: ignore
            self, batch: Batch, batch_size: int, repeat: int,
            **kwargs: Any) -> Dict[str, float]:
        for _ in range(repeat):
            for b in batch.split(batch_size, merge_last=True):
                self.optim.zero_grad()
                state_values = self.value_net(b.obs, b.act).flatten()
                advantages = b.advantages.flatten()
                dist = self(b).dist
                actions = to_torch_as(b.act, dist.logits)
                rewards = to_torch_as(b.returns, dist.logits)

                # Advantage estimation
                adv = advantages - state_values
                adv_squared = adv.pow(2).mean()

                # Value loss
                v_loss = 0.5 * adv_squared

                # Policy loss: exponentially weighted advantages.
                # NOTE: the weighting coefficient self._beta is assumed to be
                # an attribute of this policy; adjust to the actual name used.
                exp_advs = torch.exp(self._beta * adv)
                # log \pi_\theta(a|s)
                log_prob = dist.log_prob(actions).reshape(
                    len(rewards), -1).transpose(0, 1)
                p_loss = -(log_prob * exp_advs.detach()).mean()

                # Combine both losses
                loss = p_loss + self.vf_coeff * v_loss
                loss.backward()
                self.optim.step()

        # Fraction of the value-target variance explained by the critic
        explained_variance = 1.0 - adv.var() / advantages.var().clamp(min=1e-8)
        return {
            "policy_loss": p_loss.item(),
            "vf_loss": v_loss.item(),
            "total_loss": loss.item(),
            "vf_explained_var": explained_variance.item(),
        }
Example #23
 def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                indice: np.ndarray) -> Batch:
     if self._recompute_adv:
          # store `buffer` and `indice` so they can be reused in `learn()`.
         self._buffer, self._indice = buffer, indice
     batch = self._compute_returns(batch, buffer, indice)
     batch.act = to_torch_as(batch.act, batch.v_s)
     old_log_prob = []
     with torch.no_grad():
         for b in batch.split(self._batch, shuffle=False, merge_last=True):
             old_log_prob.append(self(b).dist.log_prob(b.act))
     batch.logp_old = torch.cat(old_log_prob, dim=0)
     return batch
Example #24
    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        if self._target and self._iter % self._freq == 0:
            self.sync_weight()
        self.optim.zero_grad()

        # compute dqn loss
        weight = batch.pop("weight", 1.0)
        q = self(batch).logits
        q = q[np.arange(len(q)), batch.act]
        r = to_torch_as(batch.returns.flatten(), q)
        weight = to_torch_as(weight, q)
        td = r - q
        dqn_loss = (td.pow(2) * weight).mean()
        batch.weight = td  # prio-buffer

        # compute opd loss
        slow_preds = self(batch, output_dqn=False, output_opd=True).logits
        fast_preds = self(batch,
                          input="fast_obs",
                          output_dqn=False,
                          output_opd=True).logits
        # act_dim + 1 classes; the last class indicates that the state is off-policy
        slow_label = torch.ones_like(
            slow_preds[:, 0], dtype=torch.int64) * int(self.model.output_dim)
        fast_label = to_torch_as(torch.tensor(batch.fast_act), slow_label)
        opd_loss = F.cross_entropy(slow_preds, slow_label) + \
                    F.cross_entropy(fast_preds, fast_label)

        # compute loss and back prop
        loss = dqn_loss + self.opd_loss_coeff * opd_loss
        loss.backward()
        self.optim.step()
        self._iter += 1
        return {
            "loss": loss.item(),
            "dqn_loss": dqn_loss.item(),
            "opd_loss": opd_loss.item(),
        }
Example #25
 def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
     if self._target and self._cnt % self._freq == 0:
         self.sync_weight()
     self.optim.zero_grad()
     weight = batch.pop('weight', 1.)
     q = self(batch, eps=0., is_learning=True).logits
     # q = torch.tensor(q, requires_grad=True)
     # q = q[np.arange(len(q)), batch.act]
     r = to_torch_as(batch.policy.reshape((-1, 1)), q)
     loss = F.mse_loss(r, q)
     loss.backward()
     self.optim.step()
     self._cnt += 1
     return {'loss': loss.item()}
Example #26
    def _compute_return(
        self,
        batch: Batch,
        buffer: ReplayBuffer,
        indice: np.ndarray,
        gamma: float = 0.99,
    ) -> Batch:
        rew = batch.rew
        with torch.no_grad():
            target_q_torch = self._target_q(buffer, indice)  # (bsz, ?)
        target_q = to_numpy(target_q_torch)
        end_flag = buffer.done.copy()
        end_flag[buffer.unfinished_index()] = True
        end_flag = end_flag[indice]
        mean_target_q = np.mean(target_q, -1) if len(target_q.shape) > 1 else target_q
        _target_q = rew + gamma * mean_target_q * (1 - end_flag)
        target_q = np.repeat(_target_q[..., None], self.num_branches, axis=-1)
        target_q = np.repeat(target_q[..., None], self.max_action_num, axis=-1)

        batch.returns = to_torch_as(target_q, target_q_torch)
        if hasattr(batch, "weight"):  # prio buffer update
            batch.weight = to_torch_as(batch.weight, target_q_torch)
        return batch
Example #27
 def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
     if self._target and self._cnt % self._freq == 0:
         self.sync_weight()
     self.optim.zero_grad()
     q = self(batch, eps=0.).logits
     q = q[np.arange(len(q)), batch.act]
     r = to_torch_as(batch.returns, q).flatten()
     td = r - q
     loss = (td.pow(2) * batch.weight).mean()
     batch.weight = td  # prio-buffer
     loss.backward()
     self.optim.step()
     self._cnt += 1
     return {'loss': loss.item()}
Example #28
 def forward(self, s, a=None):
     """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
     s = to_torch(s, device=self.device, dtype=torch.float32)
     # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
     # In short, during training the tensor carries an extra time dimension
     # compared with evaluation.
     assert len(s.shape) == 3
     self.nn.flatten_parameters()
     s, (h, c) = self.nn(s)
     s = s[:, -1]
     if a is not None:
         a = to_torch_as(a, s)
         s = torch.cat([s, a], dim=1)
     s = self.fc2(s)
     return s
Example #29
 def forward(
     self,
     s: Union[np.ndarray, torch.Tensor],
     a: Optional[Union[np.ndarray, torch.Tensor]] = None,
     info: Dict[str, Any] = {},
 ) -> torch.Tensor:
     """Mapping: (s, a) -> logits -> Q(s, a)."""
     s = to_torch(s, device=self.device, dtype=torch.float32)
     s = s.flatten(1)
     if a is not None:
         a = to_torch_as(a, s)
         a = a.flatten(1)
         s = torch.cat([s, a], dim=1)
     logits, h = self.preprocess(s)
     logits = self.last(logits)
     return logits
Example #30
 def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
     current_q = self.critic(batch.obs, batch.act)
     target_q = to_torch_as(batch.returns, current_q)
     target_q = target_q[:, None]
     critic_loss = F.mse_loss(current_q, target_q)
     self.critic_optim.zero_grad()
     critic_loss.backward()
     self.critic_optim.step()
     actor_loss = -self.critic(batch.obs, self(batch, eps=0).act).mean()
     self.actor_optim.zero_grad()
     actor_loss.backward()
     self.actor_optim.step()
     self.sync_weight()
     return {
         'loss/actor': actor_loss.item(),
         'loss/critic': critic_loss.item(),
     }
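The two updates above are the standard DDPG objectives:

.. math::
    \mathcal{L}_{\mathrm{critic}} = \bigl(Q(s, a) - G\bigr)^2, \qquad
    \mathcal{L}_{\mathrm{actor}} = - Q\bigl(s, \mu(s)\bigr),

where :math:`G` is the precomputed target stored in ``batch.returns`` and :math:`\mu` the deterministic actor (here invoked with ``eps=0``).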