Example #1
 def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     weight = batch.pop("weight", 1.0)
     # critic 1
     current_q1 = self.critic1(batch.obs, batch.act).flatten()
     target_q = batch.returns.flatten()
     td1 = current_q1 - target_q
     critic1_loss = (td1.pow(2) * weight).mean()
     # critic1_loss = F.mse_loss(current_q1, target_q)
     self.critic1_optim.zero_grad()
     critic1_loss.backward()
     self.critic1_optim.step()
     # critic 2
     current_q2 = self.critic2(batch.obs, batch.act).flatten()
     td2 = current_q2 - target_q
     critic2_loss = (td2.pow(2) * weight).mean()
     # critic2_loss = F.mse_loss(current_q2, target_q)
     self.critic2_optim.zero_grad()
     critic2_loss.backward()
     self.critic2_optim.step()
     batch.weight = (td1 + td2) / 2.0  # prio-buffer
     if self._cnt % self._freq == 0:
         actor_loss = -self.critic1(batch.obs,
                                    self(batch, eps=0.0).act).mean()
         self.actor_optim.zero_grad()
         actor_loss.backward()
         self._last = actor_loss.item()
         self.actor_optim.step()
         self.sync_weight()
     self._cnt += 1
     return {
         "loss/actor": self._last,
         "loss/critic1": critic1_loss.item(),
         "loss/critic2": critic2_loss.item(),
     }
Example #2
    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        # critic 1&2
        td1, critic1_loss = self._mse_optimizer(batch, self.critic1,
                                                self.critic1_optim,
                                                self.scaler, self.use_mixed)
        td2, critic2_loss = self._mse_optimizer(batch, self.critic2,
                                                self.critic2_optim,
                                                self.scaler, self.use_mixed)
        batch.weight = (td1 + td2) / 2.0  # prio-buffer

        # actor
        if self._cnt % self._freq == 0:
            actor_loss = -self.critic1(batch.obs,
                                       self(batch, eps=0.0).act).mean()
            self.actor_optim.zero_grad()
            self.scaler.scale(actor_loss).backward()
            # actor_loss.backward()
            self._last = actor_loss.item()
            self.scaler.step(self.actor_optim)
            # self.actor_optim.step()
            self.sync_weight()
        # update the loss scale once per iteration, after all scaler.step() calls above
        self.scaler.update()
        self._cnt += 1
        return {
            "loss/actor": self._last,
            "loss/critic1": critic1_loss.item(),
            "loss/critic2": critic2_loss.item(),
        }
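Example #2 wires a torch.cuda.amp.GradScaler into the update. For reference, here is a minimal, self-contained sketch of the canonical AMP training-step pattern it follows (the toy model, data, and learning rate are assumptions, and a CUDA device is required for autocast to have an effect):

    import torch
    from torch.cuda.amp import GradScaler, autocast

    # toy model/optimizer; values here are illustrative only
    model = torch.nn.Linear(8, 1).cuda()
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    scaler = GradScaler()

    x = torch.randn(32, 8, device="cuda")
    y = torch.randn(32, 1, device="cuda")
    with autocast():
        loss = torch.nn.functional.mse_loss(model(x), y)
    optim.zero_grad()
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optim)             # unscales gradients, skips the step on inf/nan
    scaler.update()                # adjust the scale once per iteration, after all step() calls

This is consistent with Example #2 calling scaler.update() once at the end of learn(), presumably after the critic steps inside _mse_optimizer and the optional actor step.
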
Example #3
    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        # critic 1&2
        td1, critic1_loss = self._mse_optimizer(
            batch, self.critic1, self.critic1_optim
        )
        td2, critic2_loss = self._mse_optimizer(
            batch, self.critic2, self.critic2_optim
        )
        batch.weight = (td1 + td2) / 2.0  # prio-buffer

        # actor
        if self._cnt % self._freq == 0:
            act = self(batch, eps=0.0).act
            q_value = self.critic1(batch.obs, act)
            lmbda = self._alpha / q_value.abs().mean().detach()
            actor_loss = -lmbda * q_value.mean() + F.mse_loss(
                act, to_torch_as(batch.act, act)
            )
            self.actor_optim.zero_grad()
            actor_loss.backward()
            self._last = actor_loss.item()
            self.actor_optim.step()
            self.sync_weight()
        self._cnt += 1
        return {
            "loss/actor": self._last,
            "loss/critic1": critic1_loss.item(),
            "loss/critic2": critic2_loss.item(),
        }
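The actor update in Example #3 is the TD3+BC objective: the Q term is rescaled by lambda = alpha / mean|Q| so that it stays on the same scale as the behaviour-cloning MSE term. A minimal sketch with toy tensors in place of the actor and critic outputs (the alpha value, shapes, and the linear stand-in critic are assumptions):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    alpha = 2.5                                   # assumed trade-off coefficient
    critic = torch.nn.Linear(6, 1)                # toy stand-in for critic1(obs, act)
    dataset_act = torch.rand(32, 6)               # actions stored in the offline batch
    policy_act = (dataset_act + 0.1 * torch.randn(32, 6)).requires_grad_(True)

    q_value = critic(policy_act)
    lmbda = alpha / q_value.abs().mean().detach()     # detached: only rescales, no gradient
    actor_loss = -lmbda * q_value.mean() + F.mse_loss(policy_act, dataset_act)
    actor_loss.backward()
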
Example #4
 def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     if self._target and self._iter % self._freq == 0:
         self.sync_weight()
     self.optim.zero_grad()
     weight = batch.pop("weight", 1.0)
     all_dist = self(batch).logits
     act = to_torch(batch.act, dtype=torch.long, device=all_dist.device)
     curr_dist = all_dist[np.arange(len(act)), act, :].unsqueeze(2)
     target_dist = batch.returns.unsqueeze(1)
     # calculate each element's difference between curr_dist and target_dist
     u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
     huber_loss = (
         u * (self.tau_hat -
              (target_dist - curr_dist).detach().le(0.).float()).abs()
     ).sum(-1).mean(1)
     qr_loss = (huber_loss * weight).mean()
     # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
     # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
     batch.weight = u.detach().abs().sum(-1).mean(1)  # prio-buffer
     # add CQL loss
     q = self.compute_q_value(all_dist, None)
     dataset_expec = q.gather(1, act.unsqueeze(1)).mean()
     negative_sampling = q.logsumexp(1).mean()
     min_q_loss = negative_sampling - dataset_expec
     loss = qr_loss + min_q_loss * self._min_q_weight
     loss.backward()
     self.optim.step()
     self._iter += 1
     return {
         "loss": loss.item(),
         "loss/qr": qr_loss.item(),
         "loss/cql": min_q_loss.item(),
     }
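Example #4 adds a conservative Q-learning (CQL) penalty on top of the quantile-regression loss: the logsumexp of Q over all actions minus the Q value of the dataset action. A minimal sketch with random stand-ins (batch size, action count, and the penalty weight are assumptions):

    import torch

    torch.manual_seed(0)
    q = torch.randn(32, 6, requires_grad=True)       # stand-in for Q(s, .) over 6 discrete actions
    act = torch.randint(0, 6, (32,))                 # actions taken in the offline dataset
    min_q_weight = 10.0                              # assumed penalty coefficient

    dataset_expec = q.gather(1, act.unsqueeze(1)).mean()  # E_data[Q(s, a)]
    negative_sampling = q.logsumexp(dim=1).mean()         # soft maximum over all actions
    min_q_loss = negative_sampling - dataset_expec        # pushes down Q on out-of-dataset actions
    (min_q_weight * min_q_loss).backward()
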
Example #5
 def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     if self._target and self._iter % self._freq == 0:
         self.sync_weight()
     self.optim.zero_grad()
     weight = batch.pop("weight", 1.0)
     action_batch = self(batch)
     curr_dist, taus = action_batch.logits, action_batch.taus
     act = batch.act
     curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2)
     target_dist = batch.returns.unsqueeze(1)
     # calculate each element's difference between curr_dist and target_dist
     dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
     huber_loss = (
         dist_diff *
         (taus.unsqueeze(2) -
          (target_dist - curr_dist).detach().le(0.).float()).abs()
     ).sum(-1).mean(1)
     loss = (huber_loss * weight).mean()
     # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
     # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
     batch.weight = dist_diff.detach().abs().sum(-1).mean(1)  # prio-buffer
     loss.backward()
     self.optim.step()
     self._iter += 1
     return {"loss": loss.item()}
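Examples #4, #5, and #10 all compute the same quantile-Huber loss: pairwise smooth-L1 differences between current and target quantiles, weighted by |tau - 1{target <= current}|. A standalone sketch with random tensors (batch size and quantile counts are assumptions; the shape conventions mirror the code above):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    bsz, n_quantiles = 32, 8
    curr_dist = torch.randn(bsz, n_quantiles, 1, requires_grad=True)  # (bsz, N, 1)
    target_dist = torch.randn(bsz, 1, n_quantiles)                    # (bsz, 1, N')
    taus = torch.rand(bsz, n_quantiles)                               # quantile fractions in (0, 1)

    # broadcasting yields all N x N' pairwise differences
    dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
    huber_loss = (
        dist_diff *
        (taus.unsqueeze(2) - (target_dist - curr_dist).detach().le(0.).float()).abs()
    ).sum(-1).mean(1)                                                 # (bsz,)
    huber_loss.mean().backward()
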
Example #6
    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        weight = batch.pop("weight", 1.0)
        target_q = batch.returns.flatten()
        act = to_torch(batch.act[:, np.newaxis],
                       device=target_q.device,
                       dtype=torch.long)

        # critic 1
        current_q1 = self.critic1(batch.obs).gather(1, act).flatten()
        td1 = current_q1 - target_q
        critic1_loss = (td1.pow(2) * weight).mean()

        self.critic1_optim.zero_grad()
        critic1_loss.backward()
        self.critic1_optim.step()

        # critic 2
        current_q2 = self.critic2(batch.obs).gather(1, act).flatten()
        td2 = current_q2 - target_q
        critic2_loss = (td2.pow(2) * weight).mean()

        self.critic2_optim.zero_grad()
        critic2_loss.backward()
        self.critic2_optim.step()
        batch.weight = (td1 + td2) / 2.0  # prio-buffer

        # actor
        dist = self(batch).dist
        entropy = dist.entropy()
        with torch.no_grad():
            current_q1a = self.critic1(batch.obs)
            current_q2a = self.critic2(batch.obs)
            q = torch.min(current_q1a, current_q2a)
        actor_loss = -(self._alpha * entropy +
                       (dist.probs * q).sum(dim=-1)).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        if self._is_auto_alpha:
            log_prob = -entropy.detach() + self._target_entropy
            alpha_loss = -(self._log_alpha * log_prob).mean()
            self._alpha_optim.zero_grad()
            alpha_loss.backward()
            self._alpha_optim.step()
            self._alpha = self._log_alpha.detach().exp()

        self.sync_weight()

        result = {
            "loss/actor": actor_loss.item(),
            "loss/critic1": critic1_loss.item(),
            "loss/critic2": critic2_loss.item(),
        }
        if self._is_auto_alpha:
            result["loss/alpha"] = alpha_loss.item()
            result["alpha"] = self._alpha.item()  # type: ignore

        return result
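When self._is_auto_alpha is set, Examples #6, #9, #16, #22, and #24 tune the entropy temperature by gradient descent on log alpha. A minimal sketch of that update in isolation (the target-entropy heuristic, learning rate, and the random entropy values are assumptions):

    import torch

    torch.manual_seed(0)
    target_entropy = 0.98 * torch.log(torch.tensor(6.0)).item()  # assumed heuristic for 6 actions
    log_alpha = torch.zeros(1, requires_grad=True)               # optimise alpha in log-space
    alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)

    entropy = torch.rand(32)                   # stand-in for dist.entropy() of the current policy
    log_prob = -entropy.detach() + target_entropy
    alpha_loss = -(log_alpha * log_prob).mean()

    alpha_optim.zero_grad()
    alpha_loss.backward()
    alpha_optim.step()
    alpha = log_alpha.detach().exp()           # temperature used in the next actor update
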
Example #7
    def compute_nstep_return(
        batch: Batch,
        buffer: ReplayBuffer,
        indice: np.ndarray,
        target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor],
        gamma: float = 0.99,
        n_step: int = 1,
        rew_norm: bool = False,
    ) -> Batch:
        r"""Compute n-step return for Q-learning targets.

        .. math::
            G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
            \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n})

        where :math:`\gamma` is the discount factor,
        :math:`\gamma \in [0, 1]`, :math:`d_t` is the done flag of step
        :math:`t`.

        :param batch: a data batch, which is equal to buffer[indice].
        :type batch: :class:`~tianshou.data.Batch`
        :param buffer: a data buffer which contains several full-episode data
            chronologically.
        :type buffer: :class:`~tianshou.data.ReplayBuffer`
        :param indice: sampled timestep.
        :type indice: numpy.ndarray
        :param function target_q_fn: a function that receives the
            :math:`t+n-1` step's data and computes the target Q value.
        :param float gamma: the discount factor, should be in [0, 1], defaults
            to 0.99.
        :param int n_step: the number of estimation steps, should be an int
            greater than 0, defaults to 1.
        :param bool rew_norm: normalize the reward to Normal(0, 1), defaults
            to False.

        :return: a Batch. The result will be stored in batch.returns as a
            torch.Tensor with shape (bsz, ).
        """
        rew = buffer.rew
        if rew_norm:
            bfr = rew[:min(len(buffer), 1000)]  # avoid large buffer
            mean, std = bfr.mean(), bfr.std()
            if np.isclose(std, 0.0, atol=1e-2):  # treat near-zero std as zero
                mean, std = 0.0, 1.0
        else:
            mean, std = 0.0, 1.0
        buf_len = len(buffer)
        terminal = (indice + n_step - 1) % buf_len
        target_q_torch = target_q_fn(buffer, terminal).flatten()  # (bsz, )
        target_q = to_numpy(target_q_torch)

        target_q = _nstep_return(rew, buffer.done, target_q, indice, gamma,
                                 n_step, len(buffer), mean, std)

        batch.returns = to_torch_as(target_q, target_q_torch)
        # prio buffer update
        if isinstance(buffer, PrioritizedReplayBuffer):
            batch.weight = to_torch_as(batch.weight, target_q_torch)
        return batch
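To make the docstring formula concrete, here is a tiny scalar sketch of the n-step target for a single start index (the rewards, done flags, and bootstrap value are made up; the _nstep_return helper referenced above computes the same quantity over whole index arrays):

    import numpy as np

    gamma, n_step = 0.99, 3
    rew = np.array([1.0, 0.0, 2.0, 1.0, 0.5])           # r_t, r_{t+1}, ...
    done = np.array([False, False, False, True, False])
    q_target = 10.0                                      # assumed Q_target(s_{t+n})
    t = 0

    returns, discount = 0.0, 1.0
    for i in range(t, t + n_step):
        returns += discount * rew[i]
        if done[i]:               # episode ends inside the window: stop bootstrapping
            discount = 0.0
            break
        discount *= gamma
    returns += discount * q_target   # gamma^n * (1 - d_{t+n}) * Q_target(s_{t+n})
    print(returns)                   # 1.0 + 0.99**2 * 2.0 + 0.99**3 * 10.0 ≈ 12.66
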
Example #8
    def compute_nstep_return(
        batch: Batch,
        buffer: ReplayBuffer,
        indice: np.ndarray,
        target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor],
        gamma: float = 0.99,
        n_step: int = 1,
        rew_norm: bool = False,
        use_mixed: bool = False,
    ) -> Batch:
        r"""Compute n-step return for Q-learning targets.

        .. math::
            G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
            \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n})

        where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`,
        :math:`d_t` is the done flag of step :math:`t`.

        :param Batch batch: a data batch, which is equal to buffer[indice].
        :param ReplayBuffer buffer: the data buffer.
        :param function target_q_fn: a function which computes the target Q value
            of "obs_next" given the data buffer and wanted indices.
        :param float gamma: the discount factor, should be in [0, 1]. Default to 0.99.
        :param int n_step: the number of estimation steps, should be an int greater
            than 0. Default to 1.
        :param bool rew_norm: normalize the reward to Normal(0, 1). Default to False.

        :return: a Batch. The result will be stored in batch.returns as a
            torch.Tensor with the same shape as target_q_fn's return tensor.
        """
        assert not rew_norm, \
            "Reward normalization in computing n-step returns is unsupported now."
        rew = buffer.rew
        bsz = len(indice)
        indices = [indice]
        for _ in range(n_step - 1):
            indices.append(buffer.next(indices[-1]))
        indices = np.stack(indices)
        # terminal indicates the buffer indices n steps after 'indice',
        # truncated at the end of each episode
        terminal = indices[-1]
        with autocast(enabled=use_mixed):
            with torch.no_grad():
                target_q_torch = target_q_fn(buffer, terminal)  # (bsz, ?)
        target_q = to_numpy(target_q_torch.float().reshape(bsz, -1))
        target_q = target_q * BasePolicy.value_mask(buffer, terminal).reshape(
            -1, 1)
        end_flag = buffer.done.copy()
        end_flag[buffer.unfinished_index()] = True
        target_q = _nstep_return(rew, end_flag, target_q, indices, gamma,
                                 n_step)

        batch.returns = to_torch_as(target_q, target_q_torch)
        if hasattr(batch, "weight"):  # prio buffer update
            batch.weight = to_torch_as(batch.weight, target_q_torch)
        return batch
Example #9
    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        weight = batch.pop("weight", 1.0)

        # critic 1
        current_q1 = self.critic1(batch.obs, batch.act).flatten()
        target_q = batch.returns.flatten()
        td1 = current_q1 - target_q
        critic1_loss = (td1.pow(2) * weight).mean()
        # critic1_loss = F.mse_loss(current_q1, target_q)
        self.critic1_optim.zero_grad()
        critic1_loss.backward()
        self.critic1_optim.step()

        # critic 2
        current_q2 = self.critic2(batch.obs, batch.act).flatten()
        td2 = current_q2 - target_q
        critic2_loss = (td2.pow(2) * weight).mean()
        # critic2_loss = F.mse_loss(current_q2, target_q)
        self.critic2_optim.zero_grad()
        critic2_loss.backward()
        self.critic2_optim.step()
        batch.weight = (td1 + td2) / 2.0  # prio-buffer

        # actor
        obs_result = self(batch)
        a = obs_result.act
        current_q1a = self.critic1(batch.obs, a).flatten()
        current_q2a = self.critic2(batch.obs, a).flatten()
        actor_loss = (self._alpha * obs_result.log_prob.flatten()
                      - torch.min(current_q1a, current_q2a)).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        if self._is_auto_alpha:
            log_prob = obs_result.log_prob.detach() + self._target_entropy
            alpha_loss = -(self._log_alpha * log_prob).mean()
            self._alpha_optim.zero_grad()
            alpha_loss.backward()
            self._alpha_optim.step()
            self._alpha = self._log_alpha.detach().exp()

        self.sync_weight()

        result = {
            "loss/actor": actor_loss.item(),
            "loss/critic1": critic1_loss.item(),
            "loss/critic2": critic2_loss.item(),
        }
        if self._is_auto_alpha:
            result["loss/alpha"] = alpha_loss.item()
            result["alpha"] = self._alpha.item()  # type: ignore

        return result
Example #10
    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        if self._target and self._iter % self._freq == 0:
            self.sync_weight()
        weight = batch.pop("weight", 1.0)
        out = self(batch)
        curr_dist_orig = out.logits
        taus, tau_hats = out.fractions.taus, out.fractions.tau_hats
        act = batch.act
        curr_dist = curr_dist_orig[np.arange(len(act)), act, :].unsqueeze(2)
        target_dist = batch.returns.unsqueeze(1)
        # calculate each element's difference between curr_dist and target_dist
        dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
        huber_loss = (
            dist_diff *
            (tau_hats.unsqueeze(2) -
             (target_dist - curr_dist).detach().le(0.).float()).abs()
        ).sum(-1).mean(1)
        quantile_loss = (huber_loss * weight).mean()
        # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
        # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
        batch.weight = dist_diff.detach().abs().sum(-1).mean(1)  # prio-buffer
        # calculate fraction loss
        with torch.no_grad():
            sa_quantile_hats = curr_dist_orig[np.arange(len(act)), act, :]
            sa_quantiles = out.quantiles_tau[np.arange(len(act)), act, :]
            # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
            # blob/master/fqf_iqn_qrdqn/agent/fqf_agent.py L169
            values_1 = sa_quantiles - sa_quantile_hats[:, :-1]
            signs_1 = sa_quantiles > torch.cat(
                [sa_quantile_hats[:, :1], sa_quantiles[:, :-1]], dim=1)

            values_2 = sa_quantiles - sa_quantile_hats[:, 1:]
            signs_2 = sa_quantiles < torch.cat(
                [sa_quantiles[:, 1:], sa_quantile_hats[:, -1:]], dim=1)

            gradient_of_taus = (torch.where(signs_1, values_1, -values_1) +
                                torch.where(signs_2, values_2, -values_2))
        fraction_loss = (gradient_of_taus * taus[:, 1:-1]).sum(1).mean()
        # calculate entropy loss
        entropy_loss = out.fractions.entropies.mean()
        fraction_entropy_loss = fraction_loss - self._ent_coef * entropy_loss
        self._fraction_optim.zero_grad()
        fraction_entropy_loss.backward(retain_graph=True)
        self._fraction_optim.step()
        self.optim.zero_grad()
        quantile_loss.backward()
        self.optim.step()
        self._iter += 1
        return {
            "loss": quantile_loss.item() + fraction_entropy_loss.item(),
            "loss/quantile": quantile_loss.item(),
            "loss/fraction": fraction_loss.item(),
            "loss/entropy": entropy_loss.item()
        }
Example #11
 def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     # critic
     td, critic_loss = self._mse_optimizer(batch, self.critic, self.critic_optim)
     batch.weight = td  # prio-buffer
     # actor
     actor_loss = -self.critic(batch.obs, self(batch).act).mean()
     self.actor_optim.zero_grad()
     actor_loss.backward()
     self.actor_optim.step()
     self.sync_weight()
     return {
         "loss/actor": actor_loss.item(),
         "loss/critic": critic_loss.item(),
     }
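The actor update in Examples #11, #20, and #23 is the deterministic policy gradient: maximise Q(s, pi(s)) by backpropagating through the critic into the actor. A toy sketch with small linear networks (architectures and shapes are assumptions):

    import torch

    torch.manual_seed(0)
    obs = torch.randn(32, 4)
    actor = torch.nn.Linear(4, 2)                        # toy deterministic policy pi(s)
    critic = torch.nn.Linear(4 + 2, 1)                   # toy Q(s, a) on concatenated input

    act = torch.tanh(actor(obs))
    actor_loss = -critic(torch.cat([obs, act], dim=-1)).mean()
    actor_optim = torch.optim.Adam(actor.parameters(), lr=1e-3)
    actor_optim.zero_grad()
    actor_loss.backward()                                # gradients flow through the critic into the actor
    actor_optim.step()                                   # only the actor parameters are stepped
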
Example #12
 def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
     if self._target and self._cnt % self._freq == 0:
         self.sync_weight()
     self.optim.zero_grad()
     q = self(batch, eps=0.).logits
     q = q[np.arange(len(q)), batch.act]
     r = to_torch_as(batch.returns, q).flatten()
     td = r - q
     loss = (td.pow(2) * batch.weight).mean()
     batch.weight = td  # prio-buffer
     loss.backward()
     self.optim.step()
     self._cnt += 1
     return {'loss': loss.item()}
Example #13
    def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
        weight = batch.pop('weight', 1.)
        # critic 1
        current_q1 = self.critic1(batch.obs, batch.act).flatten()
        target_q = batch.returns.flatten()
        td1 = current_q1 - target_q
        critic1_loss = (td1.pow(2) * weight).mean()
        # critic1_loss = F.mse_loss(current_q1, target_q)
        self.critic1_optim.zero_grad()
        critic1_loss.backward()
        self.critic1_optim.step()
        # critic 2
        current_q2 = self.critic2(batch.obs, batch.act).flatten()
        td2 = current_q2 - target_q
        critic2_loss = (td2.pow(2) * weight).mean()
        # critic2_loss = F.mse_loss(current_q2, target_q)
        self.critic2_optim.zero_grad()
        critic2_loss.backward()
        self.critic2_optim.step()
        batch.weight = (td1 + td2) / 2.  # prio-buffer
        # actor
        obs_result = self(batch, explorating=False)
        a = obs_result.act
        current_q1a = self.critic1(batch.obs, a).flatten()
        current_q2a = self.critic2(batch.obs, a).flatten()
        actor_loss = (self._alpha * obs_result.log_prob.flatten()
                      - torch.min(current_q1a, current_q2a)).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        if self._automatic_alpha_tuning:
            log_prob = (obs_result.log_prob + self._target_entropy).detach()
            alpha_loss = -(self._log_alpha * log_prob).mean()
            self._alpha_optim.zero_grad()
            alpha_loss.backward()
            self._alpha_optim.step()
            self._alpha = self._log_alpha.detach().exp()

        self.sync_weight()

        result = {
            'loss/actor': actor_loss.item(),
            'loss/critic1': critic1_loss.item(),
            'loss/critic2': critic2_loss.item(),
        }
        if self._automatic_alpha_tuning:
            result['loss/alpha'] = alpha_loss.item()
        return result
Example #14
 def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     if self._target and self._iter % self._freq == 0:
         self.sync_weight()
     self.optim.zero_grad()
     weight = batch.pop("weight", 1.0)
     q = self(batch).logits
     q = q[np.arange(len(q)), batch.act]
     r = to_torch_as(batch.returns.flatten(), q)
     td = r - q
     loss = (td.pow(2) * weight).mean()
     batch.weight = td  # prio-buffer
     loss.backward()
     self.optim.step()
     self._iter += 1
     return {"loss": loss.item()}
Example #15
    def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                   indice: np.ndarray) -> Batch:
        batch = super().process_fn(batch, buffer, indice)
        with torch.no_grad():
            # the degree of on-policyness
            opd = self(batch, output_dqn=False, output_opd=True).logits
            opd = opd[np.arange(len(opd)), batch.act]
            opd_weights = torch.sigmoid(opd * self.opd_temperature)
            opd_weights = opd_weights / torch.sum(opd_weights) * len(opd)
        opd_weights = opd_weights.detach().cpu().numpy()
        ori_weight = batch.pop("weight", 1.0)
        ori_weight *= opd_weights
        batch.weight = ori_weight

        return batch
Example #16
    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:

        # critic 1&2
        td1, critic1_loss = self._mse_optimizer(batch, self.critic1,
                                                self.critic1_optim,
                                                self.scaler, self.use_mixed)
        td2, critic2_loss = self._mse_optimizer(batch, self.critic2,
                                                self.critic2_optim,
                                                self.scaler, self.use_mixed)
        batch.weight = (td1 + td2) / 2.0  # prio-buffer

        with autocast(enabled=self.use_mixed):
            # actor
            obs_result = self(batch)
            a = obs_result.act
            current_q1a = self.critic1(batch.obs, a).flatten()
            current_q2a = self.critic2(batch.obs, a).flatten()
            actor_loss = (self._alpha * obs_result.log_prob.flatten() -
                          torch.min(current_q1a, current_q2a)).mean()

        self.actor_optim.zero_grad()
        self.scaler.scale(actor_loss).backward()
        self.scaler.step(self.actor_optim)
        self.scaler.update()
        # actor_loss.backward()
        # self.actor_optim.step()

        if self._is_auto_alpha:
            log_prob = obs_result.log_prob.detach() + self._target_entropy
            alpha_loss = -(self._log_alpha * log_prob).mean()
            self._alpha_optim.zero_grad()
            alpha_loss.backward()
            self._alpha_optim.step()
            self._alpha = self._log_alpha.detach().exp()

        self.sync_weight()

        result = {
            "loss/actor": actor_loss.item(),
            "loss/critic1": critic1_loss.item(),
            "loss/critic2": critic2_loss.item(),
        }
        if self._is_auto_alpha:
            result["loss/alpha"] = alpha_loss.item()
            result["alpha"] = self._alpha.item()  # type: ignore

        return result
Example #17
 def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     if self._target and self._iter % self._freq == 0:
         self.sync_weight()
     self.optim.zero_grad()
     with torch.no_grad():
         target_dist = self._target_dist(batch)
     weight = batch.pop("weight", 1.0)
     curr_dist = self(batch).logits
     act = batch.act
     curr_dist = curr_dist[np.arange(len(act)), act, :]
     cross_entropy = -(target_dist * torch.log(curr_dist + 1e-8)).sum(1)
     loss = (cross_entropy * weight).mean()
     # ref: https://github.com/Kaixhin/Rainbow/blob/master/agent.py L94-100
     batch.weight = cross_entropy.detach()  # prio-buffer
     loss.backward()
     self.optim.step()
     self._iter += 1
     return {"loss": loss.item()}
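The C51 update in Example #17 is a cross-entropy between the projected target distribution and the predicted atom probabilities of the chosen action; the per-sample cross-entropy doubles as the priority. A minimal sketch with random stand-ins (batch size, atom count, and the softmax stand-ins are assumptions):

    import torch

    torch.manual_seed(0)
    bsz, n_atoms = 32, 51
    logits = torch.randn(bsz, n_atoms, requires_grad=True)
    curr_dist = torch.softmax(logits, dim=-1)                         # predicted atom probabilities
    target_dist = torch.softmax(torch.randn(bsz, n_atoms), dim=-1)    # stand-in for the projected target

    cross_entropy = -(target_dist * torch.log(curr_dist + 1e-8)).sum(1)  # (bsz,), also used as priority
    loss = cross_entropy.mean()
    loss.backward()
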
Example #18
 def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     if self._target and self._cnt % self._freq == 0:
         self.sync_weight()
     self.optim.zero_grad()
     weight = batch.pop("weight", 1.0)
     q = self(batch).logits
     q = q[np.arange(len(q)), batch.act]
     r = to_torch_as(batch.returns.flatten(), q)
     td = r - q
     loss = (td.pow(2) * weight).mean()
     batch.weight = td  # prio-buffer
     loss.backward()
     # gradient clipping
     torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
     self.optim.step()
     self._cnt += 1
     # report the current learning rate (taken from the last parameter group)
     lr = self.optim.param_groups[-1]['lr']
     return {"loss": loss.item(), "lr": lr}
Example #19
 def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
     if self._target and self._cnt % self._freq == 0:
         self.sync_weight()
     weight = batch.pop("weight", 1.0)
     self.optim.zero_grad()
     q = self(batch, eps=0.).logits
     q = q[np.arange(len(q)), batch.act]
     r = to_torch_as(batch.returns, q).flatten()
     c = torch.nn.SmoothL1Loss(reduction='none')
     # c = lambda r, q: (r-q).pow(2)
     td = c(r, q)
     loss = (td * weight).mean()
     batch.weight = td  # prio-buffer (per-sample TD errors, not the scalar loss)
     loss.backward()
     if self.grad_norm_clipping:
         torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm_clipping)
     self.optim.step()
     self._cnt += 1
     return {'loss': loss.item()}
Example #20
 def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
     current_q = self.critic(batch.obs, batch.act).flatten()
     target_q = batch.returns.flatten()
     td = current_q - target_q
     critic_loss = (td.pow(2) * batch.weight).mean()
     batch.weight = td  # prio-buffer
     self.critic_optim.zero_grad()
     critic_loss.backward()
     self.critic_optim.step()
     action = self(batch, explorating=False).act
     actor_loss = -self.critic(batch.obs, action).mean()
     self.actor_optim.zero_grad()
     actor_loss.backward()
     self.actor_optim.step()
     self.sync_weight()
     return {
         'loss/actor': actor_loss.item(),
         'loss/critic': critic_loss.item(),
     }
Example #21
 def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     if self._target and self._iter % self._freq == 0:
         self.sync_weight()
     self.optim.zero_grad()
     weight = batch.pop("weight", 1.0)
     act = to_torch(batch.act, dtype=torch.long, device=batch.returns.device)
     q = self(batch).logits
     act_mask = torch.zeros_like(q)
     act_mask = act_mask.scatter_(-1, act.unsqueeze(-1), 1)
     act_q = q * act_mask
     returns = batch.returns
     returns = returns * act_mask
     td_error = returns - act_q
     loss = (td_error.pow(2).sum(-1).mean(-1) * weight).mean()
     batch.weight = td_error.sum(-1).sum(-1)  # prio-buffer
     loss.backward()
     self.optim.step()
     self._iter += 1
     return {"loss": loss.item()}
Example #22
    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        # critic ensemble
        weight = getattr(batch, "weight", 1.0)
        current_qs = self.critics(batch.obs, batch.act).flatten(1)
        target_q = batch.returns.flatten()
        td = current_qs - target_q
        critic_loss = (td.pow(2) * weight).mean()
        self.critics_optim.zero_grad()
        critic_loss.backward()
        self.critics_optim.step()
        batch.weight = torch.mean(td, dim=0)  # prio-buffer
        self.critic_gradient_step += 1

        # actor
        if self.critic_gradient_step % self.actor_delay == 0:
            obs_result = self(batch)
            a = obs_result.act
            current_qa = self.critics(batch.obs, a).mean(dim=0).flatten()
            actor_loss = (self._alpha * obs_result.log_prob.flatten() -
                          current_qa).mean()
            self.actor_optim.zero_grad()
            actor_loss.backward()
            self.actor_optim.step()

            if self._is_auto_alpha:
                log_prob = obs_result.log_prob.detach() + self._target_entropy
                alpha_loss = -(self._log_alpha * log_prob).mean()
                self._alpha_optim.zero_grad()
                alpha_loss.backward()
                self._alpha_optim.step()
                self._alpha = self._log_alpha.detach().exp()

        self.sync_weight()

        result = {"loss/critics": critic_loss.item()}
        if self.critic_gradient_step % self.actor_delay == 0:
            result["loss/actor"] = actor_loss.item()
            if self._is_auto_alpha:
                result["loss/alpha"] = alpha_loss.item()
                result["alpha"] = self._alpha.item()  # type: ignore

        return result
Example #23
 def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     weight = batch.pop("weight", 1.0)
     current_q = self.critic(batch.obs, batch.act).flatten()
     target_q = batch.returns.flatten()
     td = current_q - target_q
     critic_loss = (td.pow(2) * weight).mean()
     batch.weight = td  # prio-buffer
     self.critic_optim.zero_grad()
     critic_loss.backward()
     self.critic_optim.step()
     action = self(batch).act
     actor_loss = -self.critic(batch.obs, action).mean()
     self.actor_optim.zero_grad()
     actor_loss.backward()
     self.actor_optim.step()
     self.sync_weight()
     return {
         "loss/actor": actor_loss.item(),
         "loss/critic": critic_loss.item(),
     }
Example #24
    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        # critic 1&2
        td1, critic1_loss = self._mse_optimizer(batch, self.critic1,
                                                self.critic1_optim)
        td2, critic2_loss = self._mse_optimizer(batch, self.critic2,
                                                self.critic2_optim)
        batch.weight = (td1 + td2) / 2.0  # prio-buffer

        # actor
        obs_result = self(batch)
        act = obs_result.act
        current_q1a = self.critic1(batch.obs, act).flatten()
        current_q2a = self.critic2(batch.obs, act).flatten()
        actor_loss = (self._alpha * obs_result.log_prob.flatten() -
                      torch.min(current_q1a, current_q2a)).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        if self._is_auto_alpha:
            log_prob = obs_result.log_prob.detach() + self._target_entropy
            # please take a look at issue #258 if you'd like to change this line
            alpha_loss = -(self._log_alpha * log_prob).mean()
            self._alpha_optim.zero_grad()
            alpha_loss.backward()
            self._alpha_optim.step()
            self._alpha = self._log_alpha.detach().exp()

        self.sync_weight()

        result = {
            "loss/actor": actor_loss.item(),
            "loss/critic1": critic1_loss.item(),
            "loss/critic2": critic2_loss.item(),
        }
        if self._is_auto_alpha:
            result["loss/alpha"] = alpha_loss.item()
            result["alpha"] = self._alpha.item()  # type: ignore

        return result
Example #25
    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        if self._target and self._iter % self._freq == 0:
            self.sync_weight()
        self.optim.zero_grad()
        weight = batch.pop("weight", 1.0)
        q = self(batch).logits
        q = q[np.arange(len(q)), batch.act]
        returns = to_torch_as(batch.returns.flatten(), q)
        td_error = returns - q

        if self._clip_loss_grad:
            y = q.reshape(-1, 1)
            t = returns.reshape(-1, 1)
            loss = torch.nn.functional.huber_loss(y, t, reduction="mean")
        else:
            loss = (td_error.pow(2) * weight).mean()

        batch.weight = td_error  # prio-buffer
        loss.backward()
        self.optim.step()
        self._iter += 1
        return {"loss": loss.item()}
Example #26
    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        # critic 1&2
        td1, critic1_loss = self._mse_optimizer(
            batch, self.critic1, self.critic1_optim)
        td2, critic2_loss = self._mse_optimizer(
            batch, self.critic2, self.critic2_optim)
        batch.weight = (td1 + td2) / 2.0  # prio-buffer

        # actor
        if self._cnt % self._freq == 0:
            actor_loss = -self.critic1(batch.obs, self(batch, eps=0.0).act).mean()
            self.actor_optim.zero_grad()
            actor_loss.backward()
            self._last = actor_loss.item()
            self.actor_optim.step()
            self.sync_weight()
        self._cnt += 1
        return {
            "loss/actor": self._last,
            "loss/critic1": critic1_loss.item(),
            "loss/critic2": critic2_loss.item(),
        }
Example #27
    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        if self._target and self._iter % self._freq == 0:
            self.sync_weight()
        self.optim.zero_grad()

        # compute dqn loss
        weight = batch.pop("weight", 1.0)
        q = self(batch).logits
        q = q[np.arange(len(q)), batch.act]
        r = to_torch_as(batch.returns.flatten(), q)
        weight = to_torch_as(weight, q)
        td = r - q
        dqn_loss = (td.pow(2) * weight).mean()
        batch.weight = td  # prio-buffer

        # compute opd loss
        slow_preds = self(batch, output_dqn=False, output_opd=True).logits
        fast_preds = self(batch,
                          input="fast_obs",
                          output_dqn=False,
                          output_opd=True).logits
        # act_dim + 1 classes; the last class indicates the state is off-policy
        slow_label = torch.ones_like(
            slow_preds[:, 0], dtype=torch.int64) * int(self.model.output_dim)
        fast_label = to_torch_as(torch.tensor(batch.fast_act), slow_label)
        opd_loss = (F.cross_entropy(slow_preds, slow_label) +
                    F.cross_entropy(fast_preds, fast_label))

        # compute loss and back prop
        loss = dqn_loss + self.opd_loss_coeff * opd_loss
        loss.backward()
        self.optim.step()
        self._iter += 1
        return {
            "loss": loss.item(),
            "dqn_loss": dqn_loss.item(),
            "opd_loss": opd_loss.item(),
        }
Example #28
    def _compute_return(
        self,
        batch: Batch,
        buffer: ReplayBuffer,
        indice: np.ndarray,
        gamma: float = 0.99,
    ) -> Batch:
        rew = batch.rew
        with torch.no_grad():
            target_q_torch = self._target_q(buffer, indice)  # (bsz, ?)
        target_q = to_numpy(target_q_torch)
        end_flag = buffer.done.copy()
        end_flag[buffer.unfinished_index()] = True
        end_flag = end_flag[indice]
        mean_target_q = np.mean(target_q, -1) if len(target_q.shape) > 1 else target_q
        _target_q = rew + gamma * mean_target_q * (1 - end_flag)
        target_q = np.repeat(_target_q[..., None], self.num_branches, axis=-1)
        target_q = np.repeat(target_q[..., None], self.max_action_num, axis=-1)

        batch.returns = to_torch_as(target_q, target_q_torch)
        if hasattr(batch, "weight"):  # prio buffer update
            batch.weight = to_torch_as(batch.weight, target_q_torch)
        return batch