Example #1
  def learn(self, mem):
    # Sample transitions
    # states in range [0, 1], shape [32, 4, 84, 84]; actions shape [bs=32]
    idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(self.batch_size)
    aug_states_1 = aug(states).to(device=self.args.device)
    aug_states_2 = aug(states).to(device=self.args.device)
    # Calculate current state probabilities (online network noise already sampled)
    log_ps, _ = self.online_net(states, log=True)  # Log probabilities log p(s_t, ·; θonline); [bs, 6, 51]
    _, z_anch = self.online_net(aug_states_1, log=True)  # shape [bs, 128]
    _, z_target = self.momentum_net(aug_states_2, log=True) # shape [bs, 128]
    z_proj = torch.matmul(self.online_net.W, z_target.T) # shape [128, bs]
    logits = torch.matmul(z_anch, z_proj) # shape [bs, bs]
    logits = (logits - torch.max(logits, 1)[0][:, None])  # subtract per-row max for numerical stability
    logits = logits * 0.1  # temperature scaling
    labels = torch.arange(logits.shape[0]).long().to(device=self.args.device) # shape [bs]
    moco_loss = (nn.CrossEntropyLoss()(logits, labels)).to(device=self.args.device)

    log_ps_a = log_ps[range(self.batch_size), actions]  # log p(s_t, a_t; θonline); [bs, 51]

    with torch.no_grad():
      # Calculate nth NEXT state probabilities
      pns, _ = self.online_net(next_states)  # Probabilities p(s_t+n, ·; θonline)
      dns = self.support.expand_as(pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
      argmax_indices_ns = dns.sum(2).argmax(1)  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
      self.target_net.reset_noise()  # Sample new target net noise
      pns, _ = self.target_net(next_states)  # Probabilities p(s_t+n, ·; θtarget)
      pns_a = pns[range(self.batch_size), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

      # Compute Tz (Bellman operator T applied to z)
      Tz = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * self.support.unsqueeze(0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
      Tz = Tz.clamp(min=self.Vmin, max=self.Vmax)  # Clamp between supported values
      # Compute L2 projection of Tz onto fixed support z
      b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
      l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
      # Fix disappearing probability mass when l = b = u (b is int)
      l[(u > 0) * (l == u)] -= 1
      u[(l < (self.atoms - 1)) * (l == u)] += 1

      # Distribute probability of Tz
      m = states.new_zeros(self.batch_size, self.atoms)
      offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand(self.batch_size, self.atoms).to(actions)
      m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
      m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

    loss = -torch.sum(m * log_ps_a, 1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
    loss = loss + (moco_loss * self.coeff)
    self.online_net.zero_grad()
    curl_loss = (weights * loss).mean()
    curl_loss.backward()  # Backpropagate importance-weighted minibatch loss
    clip_grad_norm_(self.online_net.parameters(), self.norm_clip)  # Clip gradients by L2 norm
    self.optimiser.step()

    mem.update_priorities(idxs, loss.detach().cpu().numpy())  # Update priorities of sampled transitions
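
Note: the aug function used above (and in Example #2 below) is not defined in either snippet. In CURL-style agents it is typically a random crop/shift of the stacked frames, producing two independently augmented views of the same states. A minimal sketch of such an augmentation, assuming 4D [batch, channels, height, width] inputs and a pad-then-random-crop shift (the function name and pad size are illustrative, not taken from the source):

import torch
import torch.nn.functional as F

def aug(states, pad=4):
  # Random shift: replication-pad each observation, then crop back to the
  # original size at a random offset (one offset per batch element).
  n, c, h, w = states.shape
  padded = F.pad(states, (pad, pad, pad, pad), mode='replicate')
  out = torch.empty_like(states)
  for i in range(n):
    top = int(torch.randint(0, 2 * pad + 1, (1,)))
    left = int(torch.randint(0, 2 * pad + 1, (1,)))
    out[i] = padded[i, :, top:top + h, left:left + w]
  return out

A vectorised or library-based augmentation (e.g. a random crop from kornia) would serve the same purpose; the per-sample loop just keeps the sketch short.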
Example #2
    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        model: str = "model",
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        """Compute action over the given batch data.

        If you need to mask the action, please add a "mask" into batch.obs, for
        example, if we have an environment that has "0/1/2" three actions:
        ::

            batch == Batch(
                obs=Batch(
                    obs="original obs, with batch_size=1 for demonstration",
                    mask=np.array([[False, True, False]]),
                    # action 1 is available
                    # action 0 and 2 are unavailable
                ),
                ...
            )

        :param float eps: in [0, 1], for epsilon-greedy exploration method.

        :return: A :class:`~tianshou.data.Batch` which has 3 keys:

            * ``act`` the action.
            * ``logits`` the network's raw output.
            * ``state`` the hidden state.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        model = getattr(self, model)
        momentum_model = self.momentum_model
        obs = batch[input]
        obs_ = obs.obs if hasattr(obs, "obs") else obs
        obs_ = torch.as_tensor(obs_, dtype=torch.float32)
        # two independently augmented views for the contrastive objective
        obs_1 = aug(obs_)
        obs_2 = aug(obs_)

        logits, _, h = model(obs_, state=state, info=batch.info)
        _, z_anch, _ = model(obs_1)
        _, z_target, _ = momentum_model(obs_2)
        z_proj = torch.matmul(self.model.W, z_target.T)
        moco_logits = torch.matmul(z_anch, z_proj)  # contrastive logits; kept separate from the Q-network logits
        moco_logits = moco_logits - torch.max(moco_logits, 1)[0][:, None]  # subtract per-row max for numerical stability
        moco_logits = moco_logits * 0.1  # temperature scaling
        labels = torch.arange(moco_logits.shape[0], device=moco_logits.device)  # positive pairs lie on the diagonal
        moco_loss = nn.CrossEntropyLoss()(moco_logits, labels)

        q = self.compute_q_value(logits, getattr(obs, "mask", None))
        if not hasattr(self, "max_action_num"):
            self.max_action_num = q.shape[1]
        act = to_numpy(q.max(dim=1)[1])
        return Batch(logits=logits, moco_loss=moco_loss, act=act, state=h)
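
Both examples rely on a momentum (key) encoder (self.momentum_net in Example #1, self.momentum_model in Example #2) whose weights are never updated in the code shown. In CURL/MoCo the momentum network tracks the online network through an exponential moving average, theta_k <- m * theta_k + (1 - m) * theta_q. A minimal sketch of that update, assuming it is called once per learning step and that 0.999 is an illustrative momentum coefficient:

import torch

@torch.no_grad()
def update_momentum_net(online_net, momentum_net, momentum=0.999):
  # EMA update of the key encoder's parameters towards the online encoder.
  for param_q, param_k in zip(online_net.parameters(), momentum_net.parameters()):
    param_k.data.mul_(momentum).add_(param_q.data, alpha=1.0 - momentum)

In Example #1 this would be called with self.online_net and self.momentum_net; in Example #2, with self.model and self.momentum_model.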