def process_fn(self, batch: Batch, buffer: ReplayBuffer,
               indice: np.ndarray) -> Batch:
    if self._rew_norm:
        mean, std = batch.rew.mean(), batch.rew.std()
        if not np.isclose(std, 0, 1e-2):
            batch.rew = (batch.rew - mean) / std
    v, v_, old_log_prob = [], [], []
    with torch.no_grad():
        for b in batch.split(self._batch, shuffle=False):
            v_.append(self.critic(b.obs_next))
            v.append(self.critic(b.obs))
            old_log_prob.append(self(b).dist.log_prob(
                to_torch_as(b.act, v[0])))
    v_ = to_numpy(torch.cat(v_, dim=0))
    batch = self.compute_episodic_return(
        batch, v_, gamma=self._gamma, gae_lambda=self._lambda,
        rew_norm=self._rew_norm)
    batch.v = torch.cat(v, dim=0).flatten()  # old value
    batch.act = to_torch_as(batch.act, v[0])
    batch.logp_old = torch.cat(old_log_prob, dim=0)
    batch.returns = to_torch_as(batch.returns, v[0])
    batch.adv = batch.returns - batch.v
    if self._rew_norm:
        mean, std = batch.adv.mean(), batch.adv.std()
        if not np.isclose(std.item(), 0, 1e-2):
            batch.adv = (batch.adv - mean) / std
    return batch

def learn(  # type: ignore
    self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]:
    losses, actor_losses, vf_losses, ent_losses = [], [], [], []
    for _ in range(repeat):
        for b in batch.split(batch_size, merge_last=True):
            self.optim.zero_grad()
            dist = self(b).dist
            v = self.critic(b.obs).flatten()
            a = to_torch_as(b.act, v)
            r = to_torch_as(b.returns, v)
            log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
            a_loss = -(log_prob * (r - v).detach()).mean()
            vf_loss = F.mse_loss(r, v)  # type: ignore
            ent_loss = dist.entropy().mean()
            loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
            loss.backward()
            if self._grad_norm is not None:
                nn.utils.clip_grad_norm_(
                    list(self.actor.parameters()) + list(self.critic.parameters()),
                    max_norm=self._grad_norm,
                )
            self.optim.step()
            actor_losses.append(a_loss.item())
            vf_losses.append(vf_loss.item())
            ent_losses.append(ent_loss.item())
            losses.append(loss.item())
    return {
        "loss": losses,
        "loss/actor": actor_losses,
        "loss/vf": vf_losses,
        "loss/ent": ent_losses,
    }

def learn(self, batch: Batch, batch_size: int, repeat: int,
          **kwargs) -> Dict[str, List[float]]:
    self._batch = batch_size
    r = batch.returns
    if self._rew_norm and not np.isclose(r.std(), 0):
        batch.returns = (r - r.mean()) / r.std()
    losses, actor_losses, vf_losses, ent_losses = [], [], [], []
    for _ in range(repeat):
        for b in batch.split(batch_size):
            self.optim.zero_grad()
            dist = self(b).dist
            v = self.critic(b.obs).squeeze(-1)
            a = to_torch_as(b.act, v)
            r = to_torch_as(b.returns, v)
            a_loss = -(dist.log_prob(a).reshape(v.shape) * (r - v).detach()).mean()
            vf_loss = F.mse_loss(r, v)
            ent_loss = dist.entropy().mean()
            loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
            loss.backward()
            if self._grad_norm is not None:
                nn.utils.clip_grad_norm_(
                    list(self.actor.parameters()) + list(self.critic.parameters()),
                    max_norm=self._grad_norm)
            self.optim.step()
            actor_losses.append(a_loss.item())
            vf_losses.append(vf_loss.item())
            ent_losses.append(ent_loss.item())
            losses.append(loss.item())
    return {
        'loss': losses,
        'loss/actor': actor_losses,
        'loss/vf': vf_losses,
        'loss/ent': ent_losses,
    }

def _compute_returns(self, batch: Batch, buffer: ReplayBuffer,
                     indice: np.ndarray) -> Batch:
    v_s, v_s_ = [], []
    with torch.no_grad():
        for b in batch.split(self._batch, shuffle=False, merge_last=True):
            v_s.append(self.critic(b.obs))
            v_s_.append(self.critic(b.obs_next))
    batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
    v_s = batch.v_s.cpu().numpy()
    v_s_ = torch.cat(v_s_, dim=0).flatten().cpu().numpy()
    # When normalizing values, we do not subtract self.ret_rms.mean, to stay
    # numerically consistent with OpenAI Baselines' value-normalization
    # pipeline. Empirically, subtracting the mean also hurts performance
    # slightly on MuJoCo environments, for reasons that are not yet clear.
    if self._rew_norm:  # unnormalize v_s & v_s_
        v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
        v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
    unnormalized_returns, advantages = self.compute_episodic_return(
        batch, buffer, indice, v_s_, v_s,
        gamma=self._gamma, gae_lambda=self._lambda)
    if self._rew_norm:
        batch.returns = unnormalized_returns / \
            np.sqrt(self.ret_rms.var + self._eps)
        self.ret_rms.update(unnormalized_returns)
    else:
        batch.returns = unnormalized_returns
    batch.returns = to_torch_as(batch.returns, batch.v_s)
    batch.adv = to_torch_as(advantages, batch.v_s)
    return batch

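# The `self.ret_rms` object used above is not shown in this snippet. Below is a
# minimal, hypothetical sketch of the running-statistics tracker it is assumed
# to be (a RunningMeanStd-style object exposing `mean`, `var` and `update()`);
# the library's actual implementation may differ.
import numpy as np


class RunningMeanStd:
    """Track the running mean and variance of batches of returns."""

    def __init__(self, eps: float = 1e-4) -> None:
        self.mean, self.var, self.count = 0.0, 1.0, eps

    def update(self, x: np.ndarray) -> None:
        batch_mean, batch_var, batch_count = x.mean(), x.var(), x.size
        delta = batch_mean - self.mean
        total = self.count + batch_count
        # merge the two sample sets with the parallel-variance formula
        m2 = (self.var * self.count + batch_var * batch_count
              + delta ** 2 * self.count * batch_count / total)
        self.mean += delta * batch_count / total
        self.var = m2 / total
        self.count = total
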
def process_fn(
    self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> Batch:
    v_s, v_s_, old_log_prob = [], [], []
    with torch.no_grad():
        for b in batch.split(self._batch, shuffle=False, merge_last=True):
            v_s.append(self.critic(b.obs))
            v_s_.append(self.critic(b.obs_next))
            old_log_prob.append(self(b).dist.log_prob(to_torch_as(b.act, v_s[0])))
    batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
    v_s = to_numpy(batch.v_s)
    v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
    if self._rew_norm:  # unnormalize v_s & v_s_
        v_s = v_s * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
        v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
    unnormalized_returns, advantages = self.compute_episodic_return(
        batch, buffer, indice, v_s_, v_s,
        gamma=self._gamma, gae_lambda=self._lambda)
    if self._rew_norm:
        batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
            np.sqrt(self.ret_rms.var + self._eps)
        self.ret_rms.update(unnormalized_returns)
        mean, std = np.mean(advantages), np.std(advantages)
        advantages = (advantages - mean) / std  # per-batch norm
    else:
        batch.returns = unnormalized_returns
    batch.act = to_torch_as(batch.act, batch.v_s)
    batch.logp_old = torch.cat(old_log_prob, dim=0)
    batch.returns = to_torch_as(batch.returns, batch.v_s)
    batch.adv = to_torch_as(advantages, batch.v_s)
    return batch

def compute_nstep_return(
    batch: Batch,
    buffer: ReplayBuffer,
    indice: np.ndarray,
    target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor],
    gamma: float = 0.99,
    n_step: int = 1,
    rew_norm: bool = False,
) -> Batch:
    r"""Compute n-step return for Q-learning targets.

    .. math::
        G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i
        + \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n})

    where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`,
    :math:`d_t` is the done flag of step :math:`t`.

    :param batch: a data batch, which is equal to buffer[indice].
    :type batch: :class:`~tianshou.data.Batch`
    :param buffer: a data buffer which contains several full-episode data
        chronologically.
    :type buffer: :class:`~tianshou.data.ReplayBuffer`
    :param indice: the sampled timesteps.
    :type indice: numpy.ndarray
    :param function target_q_fn: a function that receives step :math:`t + n - 1`'s
        data and computes the target Q value.
    :param float gamma: the discount factor, should be in [0, 1], defaults to 0.99.
    :param int n_step: the number of estimation steps, should be an int greater
        than 0, defaults to 1.
    :param bool rew_norm: normalize the reward to Normal(0, 1), defaults to False.

    :return: a Batch. The result will be stored in batch.returns as a
        torch.Tensor with shape (bsz, ).
    """
    rew = buffer.rew
    if rew_norm:
        bfr = rew[:min(len(buffer), 1000)]  # avoid large buffer
        mean, std = bfr.mean(), bfr.std()
        if np.isclose(std, 0, 1e-2):
            mean, std = 0.0, 1.0
    else:
        mean, std = 0.0, 1.0
    buf_len = len(buffer)
    terminal = (indice + n_step - 1) % buf_len
    target_q_torch = target_q_fn(buffer, terminal).flatten()  # (bsz, )
    target_q = to_numpy(target_q_torch)

    target_q = _nstep_return(rew, buffer.done, target_q, indice,
                             gamma, n_step, len(buffer), mean, std)

    batch.returns = to_torch_as(target_q, target_q_torch)
    if isinstance(buffer, PrioritizedReplayBuffer):  # prio buffer update
        batch.weight = to_torch_as(batch.weight, target_q_torch)
    return batch

def compute_nstep_return(batch: Batch,
                         buffer: ReplayBuffer,
                         indice: np.ndarray,
                         target_q_fn: Callable[[ReplayBuffer, np.ndarray],
                                               torch.Tensor],
                         gamma: float = 0.99,
                         n_step: int = 1,
                         rew_norm: bool = False) -> Batch:
    r"""Compute n-step return for Q-learning targets.

    .. math::
        G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i
        + \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n})

    where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`,
    :math:`d_t` is the done flag of step :math:`t`.

    :param batch: a data batch, which is equal to buffer[indice].
    :type batch: :class:`~tianshou.data.Batch`
    :param buffer: a data buffer which contains several full-episode data
        chronologically.
    :type buffer: :class:`~tianshou.data.ReplayBuffer`
    :param indice: the sampled timesteps.
    :type indice: numpy.ndarray
    :param function target_q_fn: a function that receives step :math:`t + n - 1`'s
        data and computes the target Q value.
    :param float gamma: the discount factor, should be in [0, 1], defaults to 0.99.
    :param int n_step: the number of estimation steps, should be an int greater
        than 0, defaults to 1.
    :param bool rew_norm: normalize the reward to Normal(0, 1), defaults to
        ``False``.

    :return: a Batch. The result will be stored in batch.returns as a
        torch.Tensor with shape (bsz, ).
    """
    if rew_norm:
        bfr = buffer.rew[:min(len(buffer), 1000)]  # avoid large buffer
        mean, std = bfr.mean(), bfr.std()
        if np.isclose(std, 0):
            mean, std = 0, 1
    else:
        mean, std = 0, 1
    returns = np.zeros_like(indice)
    gammas = np.zeros_like(indice) + n_step
    done, rew, buf_len = buffer.done, buffer.rew, len(buffer)
    for n in range(n_step - 1, -1, -1):
        now = (indice + n) % buf_len
        gammas[done[now] > 0] = n
        returns[done[now] > 0] = 0
        returns = (rew[now] - mean) / std + gamma * returns
    terminal = (indice + n_step - 1) % buf_len
    target_q = target_q_fn(buffer, terminal).squeeze()
    target_q[gammas != n_step] = 0
    returns = to_torch_as(returns, target_q)
    gammas = to_torch_as(gamma ** gammas, target_q)
    batch.returns = target_q * gammas + returns
    return batch

def compute_nstep_return(
    batch: Batch,
    buffer: ReplayBuffer,
    indice: np.ndarray,
    target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor],
    gamma: float = 0.99,
    n_step: int = 1,
    rew_norm: bool = False,
    use_mixed: bool = False,
) -> Batch:
    r"""Compute n-step return for Q-learning targets.

    .. math::
        G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i
        + \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n})

    where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`,
    :math:`d_t` is the done flag of step :math:`t`.

    :param Batch batch: a data batch, which is equal to buffer[indice].
    :param ReplayBuffer buffer: the data buffer.
    :param function target_q_fn: a function which computes the target Q value
        of "obs_next" given the data buffer and wanted indices.
    :param float gamma: the discount factor, should be in [0, 1]. Default to 0.99.
    :param int n_step: the number of estimation steps, should be an int greater
        than 0. Default to 1.
    :param bool rew_norm: normalize the reward to Normal(0, 1). Default to False.
    :param bool use_mixed: whether to evaluate target_q_fn under mixed-precision
        autocast. Default to False.

    :return: a Batch. The result will be stored in batch.returns as a
        torch.Tensor with the same shape as target_q_fn's return tensor.
    """
    assert not rew_norm, \
        "Reward normalization in computing n-step returns is unsupported now."
    rew = buffer.rew
    bsz = len(indice)
    indices = [indice]
    for _ in range(n_step - 1):
        indices.append(buffer.next(indices[-1]))
    indices = np.stack(indices)
    # terminal indicates buffer indexes n_step after 'indice',
    # and is truncated at the end of each episode
    terminal = indices[-1]
    with autocast(enabled=use_mixed):
        with torch.no_grad():
            target_q_torch = target_q_fn(buffer, terminal)  # (bsz, ?)
    target_q = to_numpy(target_q_torch.float().reshape(bsz, -1))
    target_q = target_q * BasePolicy.value_mask(buffer, terminal).reshape(-1, 1)
    end_flag = buffer.done.copy()
    end_flag[buffer.unfinished_index()] = True
    target_q = _nstep_return(rew, end_flag, target_q, indices, gamma, n_step)

    batch.returns = to_torch_as(target_q, target_q_torch)
    if hasattr(batch, "weight"):  # prio buffer update
        batch.weight = to_torch_as(batch.weight, target_q_torch)
    return batch

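# A hedged, self-contained check of the n-step target formula above, written
# with plain numpy and hand-picked numbers (no ReplayBuffer or _nstep_return
# involved). With no episode termination inside the slice and n_step = 3:
#     G_t = r_t + gamma * r_{t+1} + gamma**2 * r_{t+2}
#           + gamma**3 * Q_target(s_{t+3})
import numpy as np

gamma = 0.9
rew = np.array([1.0, 0.5, 2.0])       # r_t, r_{t+1}, r_{t+2}
q_bootstrap = 10.0                    # Q_target(s_{t+3})
g = sum(gamma ** i * r for i, r in enumerate(rew)) + gamma ** 3 * q_bootstrap
print(round(g, 2))                    # 1.0 + 0.45 + 1.62 + 7.29 = 10.36
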
def learn(  # type: ignore
    self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]:
    losses = []
    for _ in range(repeat):
        for b in batch.split(batch_size, merge_last=True):
            self.optim.zero_grad()
            dist = self(b).dist
            a = to_torch_as(b.act, dist.logits)
            r = to_torch_as(b.returns, dist.logits)
            log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
            loss = -(log_prob * r).mean()
            loss.backward()
            self.optim.step()
            losses.append(loss.item())
    return {"loss": losses}

def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    if self._target and self._iter % self._freq == 0:
        self.sync_weight()
    self.optim.zero_grad()
    weight = batch.pop("weight", 1.0)
    q = self(batch).logits
    q = q[np.arange(len(q)), batch.act]
    r = to_torch_as(batch.returns.flatten(), q)
    weight = to_torch_as(weight, q)
    td = r - q
    loss = (td.pow(2) * weight).mean()
    batch.weight = td  # prio-buffer
    loss.backward()
    self.optim.step()
    self._iter += 1
    return {"loss": loss.item()}

def forward(self, batch: Batch,
            state: Optional[Union[dict, Batch, np.ndarray]] = None,
            model: str = 'actor',
            input: str = 'obs',
            explorating: bool = True,
            **kwargs) -> Batch:
    """Compute action over the given batch data.

    :return: A :class:`~tianshou.data.Batch` which has 2 keys:

        * ``act`` the action.
        * ``state`` the hidden state.

    .. seealso::

        Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
        more detailed explanation.
    """
    model = getattr(self, model)
    obs = getattr(batch, input)
    logits, h = model(obs, state=state, info=batch.info)
    actions = torch.tanh(logits)
    if self.training and explorating:
        actions = actions + to_torch_as(self._noise(actions.shape), actions)
    actions = actions.clamp(self._range[0], self._range[1])
    return Batch(act=actions, state=h)

def forward(  # type: ignore
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    input: str = "obs",
    **kwargs: Any,
) -> Batch:
    obs = batch[input]
    logits, h = self.actor(obs, state=state, info=batch.info)
    assert isinstance(logits, tuple)
    dist = Independent(Normal(*logits), 1)
    if self._deterministic_eval and not self.training:
        act = logits[0]
    else:
        act = dist.rsample()
    log_prob = dist.log_prob(act).unsqueeze(-1)
    # Apply the correction for tanh squashing when computing the log-prob from
    # the Gaussian; see Eq. 21 in appendix C of the original SAC paper
    # (arXiv 1801.01290) for the derivation of this equation.
    if self.action_scaling and self.action_space is not None:
        action_scale = to_torch_as(
            (self.action_space.high - self.action_space.low) / 2.0, act)
    else:
        action_scale = 1.0  # type: ignore
    squashed_action = torch.tanh(act)
    log_prob = log_prob - torch.log(
        action_scale * (1 - squashed_action.pow(2)) + self.__eps
    ).sum(-1, keepdim=True)
    return Batch(
        logits=logits,
        act=squashed_action,
        state=h,
        dist=dist,
        log_prob=log_prob)

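# A hedged numerical check (toy code, not part of the policy) of the tanh
# change-of-variables used above: for a = tanh(u) with u Gaussian and unit
# action scale, log pi(a) = log N(u) - sum_i log(1 - tanh(u_i)^2). torch's own
# TanhTransform implements the same Jacobian term, so both routes should agree.
import torch
from torch.distributions import Independent, Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

base = Independent(Normal(torch.zeros(3), torch.ones(3)), 1)
squashed = TransformedDistribution(base, [TanhTransform()])

u = torch.tensor([0.3, -1.2, 2.0])                 # pre-squash Gaussian sample
manual = base.log_prob(u) - torch.log(1 - torch.tanh(u).pow(2)).sum(-1)
auto = squashed.log_prob(torch.tanh(u))
print(torch.allclose(manual, auto, atol=1e-4))     # True (up to numerics)
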
def forward(
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    model: str = "actor",
    input: str = "obs",
    **kwargs: Any,
) -> Batch:
    """Compute action over the given batch data.

    :return: A :class:`~tianshou.data.Batch` which has 2 keys:

        * ``act`` the action.
        * ``state`` the hidden state.

    .. seealso::

        Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
        more detailed explanation.
    """
    model = getattr(self, model)
    obs = batch[input]
    actions, h = model(obs, state=state, info=batch.info)
    actions += self._action_bias
    if self._noise and not self.updating:
        actions += to_torch_as(self._noise(actions.shape), actions)
    actions = actions.clamp(self._range[0], self._range[1])
    return Batch(act=actions, state=h)

def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
    # critic 1
    current_q1 = self.critic1(batch.obs, batch.act)
    target_q = to_torch_as(batch.returns, current_q1)[:, None]
    critic1_loss = F.mse_loss(current_q1, target_q)
    self.critic1_optim.zero_grad()
    critic1_loss.backward()
    self.critic1_optim.step()
    # critic 2
    current_q2 = self.critic2(batch.obs, batch.act)
    critic2_loss = F.mse_loss(current_q2, target_q)
    self.critic2_optim.zero_grad()
    critic2_loss.backward()
    self.critic2_optim.step()
    # actor
    obs_result = self(batch)
    a = obs_result.act
    current_q1a = self.critic1(batch.obs, a)
    current_q2a = self.critic2(batch.obs, a)
    actor_loss = (self._alpha * obs_result.log_prob - torch.min(
        current_q1a, current_q2a)).mean()
    self.actor_optim.zero_grad()
    actor_loss.backward()
    self.actor_optim.step()
    self.sync_weight()
    return {
        'loss/actor': actor_loss.item(),
        'loss/critic1': critic1_loss.item(),
        'loss/critic2': critic2_loss.item(),
    }

def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    # critic 1&2
    td1, critic1_loss = self._mse_optimizer(
        batch, self.critic1, self.critic1_optim)
    td2, critic2_loss = self._mse_optimizer(
        batch, self.critic2, self.critic2_optim)
    batch.weight = (td1 + td2) / 2.0  # prio-buffer
    # actor
    if self._cnt % self._freq == 0:
        act = self(batch, eps=0.0).act
        q_value = self.critic1(batch.obs, act)
        lmbda = self._alpha / q_value.abs().mean().detach()
        actor_loss = -lmbda * q_value.mean() + F.mse_loss(
            act, to_torch_as(batch.act, act))
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self._last = actor_loss.item()
        self.actor_optim.step()
        self.sync_weight()
    self._cnt += 1
    return {
        "loss/actor": self._last,
        "loss/critic1": critic1_loss.item(),
        "loss/critic2": critic2_loss.item(),
    }

def forward(  # type: ignore
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    input: str = "obs",
    **kwargs: Any,
) -> Batch:
    obs = batch[input]
    logits, h = self.actor(obs, state=state, info=batch.info)
    assert isinstance(logits, tuple)
    dist = Independent(Normal(*logits), 1)
    if self._deterministic_eval and not self.training:
        x = logits[0]
    else:
        x = dist.rsample()
    y = torch.tanh(x)
    act = y * self._action_scale + self._action_bias
    # tanh-squashing correction term for the log-probability
    y = self._action_scale * (1 - y.pow(2)) + self.__eps
    log_prob = dist.log_prob(x).unsqueeze(-1)
    log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
    if self._noise is not None and self.training and not self.updating:
        act += to_torch_as(self._noise(act.shape), act)
    act = act.clamp(self._range[0], self._range[1])
    return Batch(
        logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)

def compute_q_value(self, logits: torch.Tensor,
                    mask: Optional[np.ndarray]) -> torch.Tensor:
    """Compute the q value based on the network's raw output and action mask."""
    if mask is not None:
        # the masked q value should be smaller than logits.min()
        min_value = logits.min() - logits.max() - 1.0
        logits = logits + to_torch_as(1 - mask, logits) * min_value
    return logits

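# A hedged toy illustration (standalone; torch.as_tensor stands in for the
# library's to_torch_as helper): masked actions receive a value strictly below
# the smallest legal Q value, so greedy action selection never picks them.
import numpy as np
import torch

logits = torch.tensor([[1.0, 5.0, -2.0]])
mask = np.array([[1, 0, 1]])            # action 1 is illegal in this state
min_value = logits.min() - logits.max() - 1.0          # -8.0 here
masked = logits + torch.as_tensor(1 - mask, dtype=logits.dtype) * min_value
print(masked)                           # tensor([[ 1., -3., -2.]])
print(masked.argmax(dim=1))             # tensor([0]); the illegal action loses
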
def learn(self, batch: Batch, batch_size: int, repeat: int,
          **kwargs) -> Dict[str, List[float]]:
    losses = []
    r = batch.returns
    if self._rew_norm and not np.isclose(r.std(), 0):
        batch.returns = (r - r.mean()) / r.std()
    for _ in range(repeat):
        for b in batch.split(batch_size):
            self.optim.zero_grad()
            dist = self(b).dist
            a = to_torch_as(b.act, dist.logits)
            r = to_torch_as(b.returns, dist.logits)
            loss = -(dist.log_prob(a) * r).sum()
            loss.backward()
            self.optim.step()
            losses.append(loss.item())
    return {'loss': losses}

def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
    if self._target and self._cnt % self._freq == 0:
        self.sync_weight()
    self.optim.zero_grad()
    q = self(batch).logits
    q = q[np.arange(len(q)), batch.act]
    r = to_torch_as(batch.returns, q)
    if hasattr(batch, 'update_weight'):  # prioritized replay
        td = r - q
        batch.update_weight(batch.indice, to_numpy(td))
        impt_weight = to_torch_as(batch.impt_weight, q)
        loss = (td.pow(2) * impt_weight).mean()
    else:
        loss = F.mse_loss(q, r)
    loss.backward()
    self.optim.step()
    self._cnt += 1
    return {'loss': loss.item()}

def _target_q(self, buffer: ReplayBuffer,
              indice: np.ndarray) -> torch.Tensor:
    batch = buffer[indice]  # batch.obs: s_{t+n}
    with torch.no_grad():
        obs_next_result = self(batch, input='obs_next', explorating=False)
        a_ = obs_next_result.act
        batch.act = to_torch_as(batch.act, a_)
        target_q = torch.min(
            self.critic1_old(batch.obs_next, a_),
            self.critic2_old(batch.obs_next, a_),
        ) - self._alpha * obs_next_result.log_prob
    return target_q

def learn(  # type: ignore
    self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]:
    losses = []
    for _ in range(repeat):
        for b in batch.split(batch_size, merge_last=True):
            self.optim.zero_grad()
            result = self(b)
            dist = result.dist
            a = to_torch_as(b.act, result.act)
            r = to_torch_as(b.returns, result.act)
            log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
            loss = -(log_prob * r).mean()
            loss.backward()
            self.optim.step()
            losses.append(loss.item())
    # update learning rate if lr_scheduler is given
    if self.lr_scheduler is not None:
        self.lr_scheduler.step()
    return {"loss": losses}

def learn(  # type: ignore
    self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, float]:
    for _ in range(repeat):
        for b in batch.split(batch_size, merge_last=True):
            self.optim.zero_grad()
            state_values = self.value_net(b.obs, b.act).flatten()
            advantages = b.advantages.flatten()
            dist = self(b).dist
            actions = to_torch_as(b.act, dist.logits)
            rewards = to_torch_as(b.returns, dist.logits)
            # Advantage estimation
            adv = advantages - state_values
            adv_squared = adv.pow(2).mean()
            # Value loss
            v_loss = 0.5 * adv_squared
            # Policy loss: update the running squared-advantage norm, then
            # exponentially weight the advantages. The norm buffer
            # `self._sqd_adv_norm`, its update rate `self._norm_update_rate`
            # and the inverse temperature `self._beta` are assumed attributes;
            # the original snippet left this step blank.
            self._sqd_adv_norm += self._norm_update_rate * (
                adv_squared.detach() - self._sqd_adv_norm)
            # Exponentially weighted advantages
            exp_advs = torch.exp(
                self._beta * adv / (self._sqd_adv_norm.sqrt() + 1e-8))
            # log \pi_\theta(a|s)
            log_prob = dist.log_prob(actions).reshape(
                len(rewards), -1).transpose(0, 1)
            p_loss = -(log_prob * exp_advs.detach()).mean()
            # Combine both losses
            loss = p_loss + self.vf_coeff * v_loss
            loss.backward()
            self.optim.step()
    # fraction of the return variance explained by the value function
    explained_variance = 1.0 - adv.var() / rewards.var().clamp(min=1e-8)
    return {
        "policy_loss": p_loss.item(),
        "vf_loss": v_loss.item(),
        "total_loss": loss.item(),
        "vf_explained_var": explained_variance.item(),
    }

def process_fn(self, batch: Batch, buffer: ReplayBuffer,
               indice: np.ndarray) -> Batch:
    if self._recompute_adv:
        # keep the input `buffer` and `indice` so they can be reused in `learn()`
        self._buffer, self._indice = buffer, indice
    batch = self._compute_returns(batch, buffer, indice)
    batch.act = to_torch_as(batch.act, batch.v_s)
    old_log_prob = []
    with torch.no_grad():
        for b in batch.split(self._batch, shuffle=False, merge_last=True):
            old_log_prob.append(self(b).dist.log_prob(b.act))
    batch.logp_old = torch.cat(old_log_prob, dim=0)
    return batch

def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    if self._target and self._iter % self._freq == 0:
        self.sync_weight()
    self.optim.zero_grad()
    # compute dqn loss
    weight = batch.pop("weight", 1.0)
    q = self(batch).logits
    q = q[np.arange(len(q)), batch.act]
    r = to_torch_as(batch.returns.flatten(), q)
    weight = to_torch_as(weight, q)
    td = r - q
    dqn_loss = (td.pow(2) * weight).mean()
    batch.weight = td  # prio-buffer
    # compute opd loss
    slow_preds = self(batch, output_dqn=False, output_opd=True).logits
    fast_preds = self(
        batch, input="fast_obs", output_dqn=False, output_opd=True).logits
    # act_dim + 1 classes; the last class indicates that the state is off-policy
    slow_label = torch.ones_like(
        slow_preds[:, 0], dtype=torch.int64) * int(self.model.output_dim)
    fast_label = to_torch_as(torch.tensor(batch.fast_act), slow_label)
    opd_loss = F.cross_entropy(slow_preds, slow_label) + \
        F.cross_entropy(fast_preds, fast_label)
    # compute loss and back prop
    loss = dqn_loss + self.opd_loss_coeff * opd_loss
    loss.backward()
    self.optim.step()
    self._iter += 1
    return {
        "loss": loss.item(),
        "dqn_loss": dqn_loss.item(),
        "opd_loss": opd_loss.item(),
    }

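# A hedged toy illustration (random tensors, not library code) of the
# (act_dim + 1)-class labelling used above: slow/off-policy observations are
# assigned the extra class index act_dim, while fast observations are labelled
# with the action that was actually taken.
import torch
import torch.nn.functional as F

act_dim, bsz = 4, 3
slow_preds = torch.randn(bsz, act_dim + 1)
fast_preds = torch.randn(bsz, act_dim + 1)
slow_label = torch.full((bsz,), act_dim, dtype=torch.int64)  # "off-policy" class
fast_label = torch.tensor([0, 2, 1])                         # actions taken
opd_loss = F.cross_entropy(slow_preds, slow_label) + \
    F.cross_entropy(fast_preds, fast_label)
print(opd_loss.item())
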
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
    if self._target and self._cnt % self._freq == 0:
        self.sync_weight()
    self.optim.zero_grad()
    weight = batch.pop('weight', 1.)
    q = self(batch, eps=0., is_learning=True).logits
    r = to_torch_as(batch.policy.reshape((-1, 1)), q)
    loss = F.mse_loss(r, q)
    loss.backward()
    self.optim.step()
    self._cnt += 1
    return {'loss': loss.item()}

def _compute_return(
    self,
    batch: Batch,
    buffer: ReplayBuffer,
    indice: np.ndarray,
    gamma: float = 0.99,
) -> Batch:
    rew = batch.rew
    with torch.no_grad():
        target_q_torch = self._target_q(buffer, indice)  # (bsz, ?)
    target_q = to_numpy(target_q_torch)
    end_flag = buffer.done.copy()
    end_flag[buffer.unfinished_index()] = True
    end_flag = end_flag[indice]
    mean_target_q = np.mean(target_q, -1) if len(target_q.shape) > 1 else target_q
    _target_q = rew + gamma * mean_target_q * (1 - end_flag)
    target_q = np.repeat(_target_q[..., None], self.num_branches, axis=-1)
    target_q = np.repeat(target_q[..., None], self.max_action_num, axis=-1)
    batch.returns = to_torch_as(target_q, target_q_torch)
    if hasattr(batch, "weight"):  # prio buffer update
        batch.weight = to_torch_as(batch.weight, target_q_torch)
    return batch

def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
    if self._target and self._cnt % self._freq == 0:
        self.sync_weight()
    self.optim.zero_grad()
    q = self(batch, eps=0.).logits
    q = q[np.arange(len(q)), batch.act]
    r = to_torch_as(batch.returns, q).flatten()
    td = r - q
    loss = (td.pow(2) * batch.weight).mean()
    batch.weight = td  # prio-buffer
    loss.backward()
    self.optim.step()
    self._cnt += 1
    return {'loss': loss.item()}

def forward(self, s, a=None):
    """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
    s = to_torch(s, device=self.device, dtype=torch.float32)
    # s is [bsz, len, dim] (training) or [bsz, dim] (evaluation); in short,
    # the tensor carries one more (time) dimension during training than
    # during evaluation.
    assert len(s.shape) == 3
    self.nn.flatten_parameters()
    s, (h, c) = self.nn(s)
    s = s[:, -1]
    if a is not None:
        a = to_torch_as(a, s)
        s = torch.cat([s, a], dim=1)
    s = self.fc2(s)
    return s

def forward(
    self,
    s: Union[np.ndarray, torch.Tensor],
    a: Optional[Union[np.ndarray, torch.Tensor]] = None,
    info: Dict[str, Any] = {},
) -> torch.Tensor:
    """Mapping: (s, a) -> logits -> Q(s, a)."""
    s = to_torch(s, device=self.device, dtype=torch.float32)
    s = s.flatten(1)
    if a is not None:
        a = to_torch_as(a, s)
        a = a.flatten(1)
        s = torch.cat([s, a], dim=1)
    logits, h = self.preprocess(s)
    logits = self.last(logits)
    return logits

def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
    # critic update
    current_q = self.critic(batch.obs, batch.act)
    target_q = to_torch_as(batch.returns, current_q)
    target_q = target_q[:, None]
    critic_loss = F.mse_loss(current_q, target_q)
    self.critic_optim.zero_grad()
    critic_loss.backward()
    self.critic_optim.step()
    # actor update
    actor_loss = -self.critic(batch.obs, self(batch, eps=0).act).mean()
    self.actor_optim.zero_grad()
    actor_loss.backward()
    self.actor_optim.step()
    self.sync_weight()
    return {
        'loss/actor': actor_loss.item(),
        'loss/critic': critic_loss.item(),
    }
