def learn(self, mem):
    # Sample transitions: states are in [0, 1] with shape [32, 4, 84, 84]; actions have shape [bs=32]
    idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(self.batch_size)
    aug_states_1 = aug(states).to(device=self.args.device)
    aug_states_2 = aug(states).to(device=self.args.device)

    # Calculate current state probabilities (online network noise already sampled)
    log_ps, _ = self.online_net(states, log=True)  # Log probabilities log p(s_t, ·; θonline); [bs, 6, 51]
    _, z_anch = self.online_net(aug_states_1, log=True)  # Anchor embedding; shape [bs, 128]
    _, z_target = self.momentum_net(aug_states_2, log=True)  # Target embedding; shape [bs, 128]
    z_proj = torch.matmul(self.online_net.W, z_target.T)  # shape [128, bs]
    logits = torch.matmul(z_anch, z_proj)  # shape [bs, bs]
    logits = logits - torch.max(logits, 1)[0][:, None]
    logits = logits * 0.1
    labels = torch.arange(logits.shape[0]).long().to(device=self.args.device)  # shape [bs]
    moco_loss = nn.CrossEntropyLoss()(logits, labels).to(device=self.args.device)
    log_ps_a = log_ps[range(self.batch_size), actions]  # log p(s_t, a_t; θonline); [bs, 51]

    with torch.no_grad():
        # Calculate nth next state probabilities
        pns, _ = self.online_net(next_states)  # Probabilities p(s_t+n, ·; θonline)
        dns = self.support.expand_as(pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
        argmax_indices_ns = dns.sum(2).argmax(1)  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
        self.target_net.reset_noise()  # Sample new target net noise
        pns, _ = self.target_net(next_states)  # Probabilities p(s_t+n, ·; θtarget)
        pns_a = pns[range(self.batch_size), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

        # Compute Tz (Bellman operator T applied to z)
        Tz = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * self.support.unsqueeze(0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
        Tz = Tz.clamp(min=self.Vmin, max=self.Vmax)  # Clamp between supported values
        # Compute L2 projection of Tz onto fixed support z
        b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
        l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
        # Fix disappearing probability mass when l = b = u (b is int)
        l[(u > 0) * (l == u)] -= 1
        u[(l < (self.atoms - 1)) * (l == u)] += 1

        # Distribute probability of Tz
        m = states.new_zeros(self.batch_size, self.atoms)
        offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand(self.batch_size, self.atoms).to(actions)
        m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
        m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

    loss = -torch.sum(m * log_ps_a, 1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
    loss = loss + (moco_loss * self.coeff)
    self.online_net.zero_grad()
    curl_loss = (weights * loss).mean()  # Importance-weighted minibatch loss
    curl_loss.backward()  # Backpropagate
    clip_grad_norm_(self.online_net.parameters(), self.norm_clip)  # Clip gradients by L2 norm
    self.optimiser.step()

    mem.update_priorities(idxs, loss.detach().cpu().numpy())  # Update priorities of sampled transitions
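The learn step above (and the forward pass below) calls an aug routine that produces randomly shifted views of the input frames, but its definition is not shown here. A minimal sketch of one common choice, replication padding followed by a random crop back to the original size, is given below; the function name, padding width, and pure-PyTorch implementation are assumptions for illustration rather than the exact transform used above.

import torch
import torch.nn.functional as F

def aug(states, pad=4):
    # Hypothetical random-shift augmentation: replication-pad each frame stack,
    # then take an independent random crop per sample back to the original size.
    n, c, h, w = states.shape
    padded = F.pad(states, (pad, pad, pad, pad), mode="replicate")
    out = torch.empty_like(states)
    for i in range(n):
        top = int(torch.randint(0, 2 * pad + 1, (1,)))
        left = int(torch.randint(0, 2 * pad + 1, (1,)))
        out[i] = padded[i, :, top:top + h, left:left + w]
    return out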
def forward(
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    model: str = "model",
    input: str = "obs",
    **kwargs: Any,
) -> Batch:
    """Compute action over the given batch data.

    If you need to mask the action, please add a "mask" into batch.obs, for
    example, if we have an environment that has "0/1/2" three actions:
    ::

        batch == Batch(
            obs=Batch(
                obs="original obs, with batch_size=1 for demonstration",
                mask=np.array([[False, True, False]]),
                # action 1 is available; actions 0 and 2 are unavailable
            ),
            ...
        )

    :return: A :class:`~tianshou.data.Batch` which has 4 keys:

        * ``act`` the action.
        * ``logits`` the network's raw output.
        * ``moco_loss`` the contrastive loss on the two augmented views.
        * ``state`` the hidden state.

    .. seealso::

        Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for more
        detailed explanation.
    """
    model = getattr(self, model)
    momentum_model = self.momentum_model
    obs = batch[input]
    obs_ = obs.obs if hasattr(obs, "obs") else obs
    # Augment the observation twice to obtain the two contrastive views
    obs_ = torch.as_tensor(obs_, dtype=torch.float32)
    obs_1 = aug(obs_)
    obs_2 = aug(obs_)
    logits, _, h = model(obs_, state=state, info=batch.info)
    _, z_anch, _ = model(obs_1)  # Anchor embedding from the online network
    _, z_target, _ = momentum_model(obs_2)  # Target embedding from the momentum network
    z_proj = torch.matmul(self.model.W, z_target.T)
    # Keep the contrastive logits separate from the Q-network logits so that
    # action selection below still uses the network's raw output
    moco_logits = torch.matmul(z_anch, z_proj)
    moco_logits = moco_logits - torch.max(moco_logits, 1)[0][:, None]
    moco_logits = moco_logits * 0.1
    labels = torch.arange(moco_logits.shape[0], device=moco_logits.device)  # Positive pairs lie on the diagonal
    moco_loss = nn.CrossEntropyLoss()(moco_logits, labels)
    q = self.compute_q_value(logits, getattr(obs, "mask", None))
    if not hasattr(self, "max_action_num"):
        self.max_action_num = q.shape[1]
    act = to_numpy(q.max(dim=1)[1])
    return Batch(logits=logits, moco_loss=moco_loss, act=act, state=h)
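Both the learn step and this forward pass query a momentum (key) encoder (self.momentum_net / self.momentum_model) whose weights are normally kept as an exponential moving average of the online network, as in MoCo. That update is not part of either snippet; a minimal sketch under that assumption, with the helper name and tau value chosen purely for illustration, might look like this:

import torch

@torch.no_grad()
def update_momentum_encoder(online_net, momentum_net, tau=0.001):
    # EMA update: θ_momentum ← (1 - τ) * θ_momentum + τ * θ_online
    for online_param, momentum_param in zip(online_net.parameters(), momentum_net.parameters()):
        momentum_param.data.mul_(1.0 - tau).add_(online_param.data, alpha=tau)

Calling a routine like this once per gradient step (or every few steps) keeps the target embeddings slowly tracking the online encoder, which is what keeps the contrastive targets stable.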