def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    weight = batch.pop("weight", 1.0)
    # critic 1
    current_q1 = self.critic1(batch.obs, batch.act).flatten()
    target_q = batch.returns.flatten()
    td1 = current_q1 - target_q
    critic1_loss = (td1.pow(2) * weight).mean()
    # critic1_loss = F.mse_loss(current_q1, target_q)
    self.critic1_optim.zero_grad()
    critic1_loss.backward()
    self.critic1_optim.step()
    # critic 2
    current_q2 = self.critic2(batch.obs, batch.act).flatten()
    td2 = current_q2 - target_q
    critic2_loss = (td2.pow(2) * weight).mean()
    # critic2_loss = F.mse_loss(current_q2, target_q)
    self.critic2_optim.zero_grad()
    critic2_loss.backward()
    self.critic2_optim.step()
    batch.weight = (td1 + td2) / 2.0  # prio-buffer
    if self._cnt % self._freq == 0:
        actor_loss = -self.critic1(batch.obs, self(batch, eps=0.0).act).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self._last = actor_loss.item()
        self.actor_optim.step()
        self.sync_weight()
    self._cnt += 1
    return {
        "loss/actor": self._last,
        "loss/critic1": critic1_loss.item(),
        "loss/critic2": critic2_loss.item(),
    }
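# Context for the "# prio-buffer" notes used throughout: learn() stores the
# per-sample TD error in batch.weight so that a prioritized replay buffer can
# refresh its priorities afterwards. A minimal sketch of that hook, assuming
# the buffer exposes a Tianshou-style update_weight(index, new_weight) method:
def post_process_fn(
    self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> None:
    """Write TD-error-based priorities back into a prioritized buffer."""
    if hasattr(buffer, "update_weight") and hasattr(batch, "weight"):
        buffer.update_weight(indice, batch.weight)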
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    # critic 1&2
    td1, critic1_loss = self._mse_optimizer(
        batch, self.critic1, self.critic1_optim, self.scaler, self.use_mixed
    )
    td2, critic2_loss = self._mse_optimizer(
        batch, self.critic2, self.critic2_optim, self.scaler, self.use_mixed
    )
    batch.weight = (td1 + td2) / 2.0  # prio-buffer
    # actor
    if self._cnt % self._freq == 0:
        actor_loss = -self.critic1(batch.obs, self(batch, eps=0.0).act).mean()
        self.actor_optim.zero_grad()
        self.scaler.scale(actor_loss).backward()  # actor_loss.backward()
        self._last = actor_loss.item()
        self.scaler.step(self.actor_optim)  # self.actor_optim.step()
        self.sync_weight()
        # TODO: verify that calling scaler.update() only here is correct,
        # given that sync_weight() runs above.
        self.scaler.update()
    self._cnt += 1
    return {
        "loss/actor": self._last,
        "loss/critic1": critic1_loss.item(),
        "loss/critic2": critic2_loss.item(),
    }
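# A sketch of the mixed-precision _mse_optimizer called above. This variant is
# an assumption (not a stock library helper): the critic forward pass runs
# under autocast and the backward/step go through the GradScaler, mirroring
# the actor update in the same method.
@staticmethod
def _mse_optimizer(
    batch: Batch,
    critic: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: torch.cuda.amp.GradScaler,
    use_mixed: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Weighted MSE critic update with optional automatic mixed precision."""
    weight = getattr(batch, "weight", 1.0)
    with autocast(enabled=use_mixed):
        current_q = critic(batch.obs, batch.act).flatten()
        target_q = batch.returns.flatten()
        td = current_q - target_q  # may be half precision under autocast
        critic_loss = (td.pow(2) * weight).mean()
    optimizer.zero_grad()
    scaler.scale(critic_loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return td.float(), critic_loss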
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    # critic 1&2
    td1, critic1_loss = self._mse_optimizer(
        batch, self.critic1, self.critic1_optim
    )
    td2, critic2_loss = self._mse_optimizer(
        batch, self.critic2, self.critic2_optim
    )
    batch.weight = (td1 + td2) / 2.0  # prio-buffer
    # actor
    if self._cnt % self._freq == 0:
        act = self(batch, eps=0.0).act
        q_value = self.critic1(batch.obs, act)
        lmbda = self._alpha / q_value.abs().mean().detach()
        actor_loss = -lmbda * q_value.mean() + F.mse_loss(
            act, to_torch_as(batch.act, act)
        )
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self._last = actor_loss.item()
        self.actor_optim.step()
        self.sync_weight()
    self._cnt += 1
    return {
        "loss/actor": self._last,
        "loss/critic1": critic1_loss.item(),
        "loss/critic2": critic2_loss.item(),
    }
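# For reference, a sketch of the shared critic-update helper used by the two
# calls above; it is consistent with the inline critic updates written out in
# the first learn() variant (weighted MSE on the TD error, one optimizer step,
# TD error returned for the prioritized buffer):
@staticmethod
def _mse_optimizer(
    batch: Batch, critic: torch.nn.Module, optimizer: torch.optim.Optimizer
) -> Tuple[torch.Tensor, torch.Tensor]:
    """A simple wrapper script for updating critic network."""
    weight = getattr(batch, "weight", 1.0)
    current_q = critic(batch.obs, batch.act).flatten()
    target_q = batch.returns.flatten()
    td = current_q - target_q
    critic_loss = (td.pow(2) * weight).mean()
    optimizer.zero_grad()
    critic_loss.backward()
    optimizer.step()
    return td, critic_loss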
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: if self._target and self._iter % self._freq == 0: self.sync_weight() self.optim.zero_grad() weight = batch.pop("weight", 1.0) all_dist = self(batch).logits act = to_torch(batch.act, dtype=torch.long, device=all_dist.device) curr_dist = all_dist[np.arange(len(act)), act, :].unsqueeze(2) target_dist = batch.returns.unsqueeze(1) # calculate each element's difference between curr_dist and target_dist u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") huber_loss = ( u * (self.tau_hat - (target_dist - curr_dist).detach().le(0.).float()).abs() ).sum(-1).mean(1) qr_loss = (huber_loss * weight).mean() # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 batch.weight = u.detach().abs().sum(-1).mean(1) # prio-buffer # add CQL loss q = self.compute_q_value(all_dist, None) dataset_expec = q.gather(1, act.unsqueeze(1)).mean() negative_sampling = q.logsumexp(1).mean() min_q_loss = negative_sampling - dataset_expec loss = qr_loss + min_q_loss * self._min_q_weight loss.backward() self.optim.step() self._iter += 1 return { "loss": loss.item(), "loss/qr": qr_loss.item(), "loss/cql": min_q_loss.item(), }
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: if self._target and self._iter % self._freq == 0: self.sync_weight() self.optim.zero_grad() weight = batch.pop("weight", 1.0) action_batch = self(batch) curr_dist, taus = action_batch.logits, action_batch.taus act = batch.act curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2) target_dist = batch.returns.unsqueeze(1) # calculate each element's difference between curr_dist and target_dist dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") huber_loss = ( dist_diff * (taus.unsqueeze(2) - (target_dist - curr_dist).detach().le(0.).float()).abs() ).sum(-1).mean(1) loss = (huber_loss * weight).mean() # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer loss.backward() self.optim.step() self._iter += 1 return {"loss": loss.item()}
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: weight = batch.pop("weight", 1.0) target_q = batch.returns.flatten() act = to_torch(batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long) # critic 1 current_q1 = self.critic1(batch.obs).gather(1, act).flatten() td1 = current_q1 - target_q critic1_loss = (td1.pow(2) * weight).mean() self.critic1_optim.zero_grad() critic1_loss.backward() self.critic1_optim.step() # critic 2 current_q2 = self.critic2(batch.obs).gather(1, act).flatten() td2 = current_q2 - target_q critic2_loss = (td2.pow(2) * weight).mean() self.critic2_optim.zero_grad() critic2_loss.backward() self.critic2_optim.step() batch.weight = (td1 + td2) / 2.0 # prio-buffer # actor dist = self(batch).dist entropy = dist.entropy() with torch.no_grad(): current_q1a = self.critic1(batch.obs) current_q2a = self.critic2(batch.obs) q = torch.min(current_q1a, current_q2a) actor_loss = -(self._alpha * entropy + (dist.probs * q).sum(dim=-1)).mean() self.actor_optim.zero_grad() actor_loss.backward() self.actor_optim.step() if self._is_auto_alpha: log_prob = -entropy.detach() + self._target_entropy alpha_loss = -(self._log_alpha * log_prob).mean() self._alpha_optim.zero_grad() alpha_loss.backward() self._alpha_optim.step() self._alpha = self._log_alpha.detach().exp() self.sync_weight() result = { "loss/actor": actor_loss.item(), "loss/critic1": critic1_loss.item(), "loss/critic2": critic2_loss.item(), } if self._is_auto_alpha: result["loss/alpha"] = alpha_loss.item() result["alpha"] = self._alpha.item() # type: ignore return result
def compute_nstep_return( batch: Batch, buffer: ReplayBuffer, indice: np.ndarray, target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor], gamma: float = 0.99, n_step: int = 1, rew_norm: bool = False, ) -> Batch: r"""Compute n-step return for Q-learning targets. .. math:: G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i + \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n}) where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`, :math:`d_t` is the done flag of step :math:`t`. :param batch: a data batch, which is equal to buffer[indice]. :type batch: :class:`~tianshou.data.Batch` :param buffer: a data buffer which contains several full-episode data chronologically. :type buffer: :class:`~tianshou.data.ReplayBuffer` :param indice: sampled timestep. :type indice: numpy.ndarray :param function target_q_fn: a function receives :math:`t+n-1` step's data and compute target Q value. :param float gamma: the discount factor, should be in [0, 1], defaults to 0.99. :param int n_step: the number of estimation step, should be an int greater than 0, defaults to 1. :param bool rew_norm: normalize the reward to Normal(0, 1), defaults to False. :return: a Batch. The result will be stored in batch.returns as a torch.Tensor with shape (bsz, ). """ rew = buffer.rew if rew_norm: bfr = rew[:min(len(buffer), 1000)] # avoid large buffer mean, std = bfr.mean(), bfr.std() if np.isclose(std, 0, 1e-2): mean, std = 0.0, 1.0 else: mean, std = 0.0, 1.0 buf_len = len(buffer) terminal = (indice + n_step - 1) % buf_len target_q_torch = target_q_fn(buffer, terminal).flatten() # (bsz, ) target_q = to_numpy(target_q_torch) target_q = _nstep_return(rew, buffer.done, target_q, indice, gamma, n_step, len(buffer), mean, std) batch.returns = to_torch_as(target_q, target_q_torch) # prio buffer update if isinstance(buffer, PrioritizedReplayBuffer): batch.weight = to_torch_as(batch.weight, target_q_torch) return batch
def compute_nstep_return(
    batch: Batch,
    buffer: ReplayBuffer,
    indice: np.ndarray,
    target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor],
    gamma: float = 0.99,
    n_step: int = 1,
    rew_norm: bool = False,
    use_mixed: bool = False,
) -> Batch:
    r"""Compute n-step return for Q-learning targets.

    .. math::
        G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i
        + \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n})

    where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`,
    and :math:`d_t` is the done flag of step :math:`t`.

    :param Batch batch: a data batch, which is equal to buffer[indice].
    :param ReplayBuffer buffer: the data buffer.
    :param function target_q_fn: a function that computes the target Q value
        of "obs_next" given the data buffer and wanted indices.
    :param float gamma: the discount factor, should be in [0, 1].
        Default to 0.99.
    :param int n_step: the number of estimation step, should be an int
        greater than 0. Default to 1.
    :param bool rew_norm: normalize the reward to Normal(0, 1).
        Default to False.
    :param bool use_mixed: whether to compute the target Q value under
        mixed-precision autocast. Default to False.

    :return: a Batch. The result will be stored in batch.returns as a
        torch.Tensor with the same shape as target_q_fn's return tensor.
    """
    assert not rew_norm, \
        "Reward normalization in computing n-step returns is unsupported now."
    rew = buffer.rew
    bsz = len(indice)
    indices = [indice]
    for _ in range(n_step - 1):
        indices.append(buffer.next(indices[-1]))
    indices = np.stack(indices)
    # terminal indicates buffer indexes nstep after 'indice',
    # and are truncated at the end of each episode
    terminal = indices[-1]
    with autocast(enabled=use_mixed):
        with torch.no_grad():
            target_q_torch = target_q_fn(buffer, terminal)  # (bsz, ?)
    target_q = to_numpy(target_q_torch.float().reshape(bsz, -1))
    target_q = target_q * BasePolicy.value_mask(buffer, terminal).reshape(-1, 1)
    end_flag = buffer.done.copy()
    end_flag[buffer.unfinished_index()] = True
    target_q = _nstep_return(rew, end_flag, target_q, indices, gamma, n_step)

    batch.returns = to_torch_as(target_q, target_q_torch)
    if hasattr(batch, "weight"):  # prio buffer update
        batch.weight = to_torch_as(batch.weight, target_q_torch)
    return batch
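# A plain-NumPy sketch of the _nstep_return kernel both versions above call
# (in the library it is numba-jitted; the older variant additionally takes the
# buffer length and normalization statistics). It folds rewards backward over
# the n-step window, zeroing and re-weighting at episode ends, then scales the
# bootstrapped target by gamma to the power of the realized horizon.
def _nstep_return(
    rew: np.ndarray,
    end_flag: np.ndarray,
    target_q: np.ndarray,
    indices: np.ndarray,
    gamma: float,
    n_step: int,
) -> np.ndarray:
    gamma_buffer = np.ones(n_step + 1)
    for i in range(1, n_step + 1):
        gamma_buffer[i] = gamma_buffer[i - 1] * gamma
    target_shape = target_q.shape
    bsz = target_shape[0]
    target_q = target_q.reshape(bsz, -1)  # treat the target as a 2d array
    returns = np.zeros(target_q.shape)
    gammas = np.full(indices[0].shape, n_step)
    for n in range(n_step - 1, -1, -1):
        now = indices[n]
        gammas[end_flag[now] > 0] = n + 1
        returns[end_flag[now] > 0] = 0.0
        returns = rew[now].reshape(bsz, 1) + gamma * returns
    target_q = target_q * gamma_buffer[gammas].reshape(bsz, 1) + returns
    return target_q.reshape(target_shape)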
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: weight = batch.pop("weight", 1.0) # critic 1 current_q1 = self.critic1(batch.obs, batch.act).flatten() target_q = batch.returns.flatten() td1 = current_q1 - target_q critic1_loss = (td1.pow(2) * weight).mean() # critic1_loss = F.mse_loss(current_q1, target_q) self.critic1_optim.zero_grad() critic1_loss.backward() self.critic1_optim.step() # critic 2 current_q2 = self.critic2(batch.obs, batch.act).flatten() td2 = current_q2 - target_q critic2_loss = (td2.pow(2) * weight).mean() # critic2_loss = F.mse_loss(current_q2, target_q) self.critic2_optim.zero_grad() critic2_loss.backward() self.critic2_optim.step() batch.weight = (td1 + td2) / 2.0 # prio-buffer # actor obs_result = self(batch) a = obs_result.act current_q1a = self.critic1(batch.obs, a).flatten() current_q2a = self.critic2(batch.obs, a).flatten() actor_loss = (self._alpha * obs_result.log_prob.flatten() - torch.min(current_q1a, current_q2a)).mean() self.actor_optim.zero_grad() actor_loss.backward() self.actor_optim.step() if self._is_auto_alpha: log_prob = obs_result.log_prob.detach() + self._target_entropy alpha_loss = -(self._log_alpha * log_prob).mean() self._alpha_optim.zero_grad() alpha_loss.backward() self._alpha_optim.step() self._alpha = self._log_alpha.detach().exp() self.sync_weight() result = { "loss/actor": actor_loss.item(), "loss/critic1": critic1_loss.item(), "loss/critic2": critic2_loss.item(), } if self._is_auto_alpha: result["loss/alpha"] = alpha_loss.item() result["alpha"] = self._alpha.item() # type: ignore return result
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: if self._target and self._iter % self._freq == 0: self.sync_weight() weight = batch.pop("weight", 1.0) out = self(batch) curr_dist_orig = out.logits taus, tau_hats = out.fractions.taus, out.fractions.tau_hats act = batch.act curr_dist = curr_dist_orig[np.arange(len(act)), act, :].unsqueeze(2) target_dist = batch.returns.unsqueeze(1) # calculate each element's difference between curr_dist and target_dist dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") huber_loss = ( dist_diff * (tau_hats.unsqueeze(2) - (target_dist - curr_dist).detach().le(0.).float()).abs() ).sum(-1).mean(1) quantile_loss = (huber_loss * weight).mean() # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer # calculate fraction loss with torch.no_grad(): sa_quantile_hats = curr_dist_orig[np.arange(len(act)), act, :] sa_quantiles = out.quantiles_tau[np.arange(len(act)), act, :] # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # blob/master/fqf_iqn_qrdqn/agent/fqf_agent.py L169 values_1 = sa_quantiles - sa_quantile_hats[:, :-1] signs_1 = sa_quantiles > torch.cat( [sa_quantile_hats[:, :1], sa_quantiles[:, :-1]], dim=1) values_2 = sa_quantiles - sa_quantile_hats[:, 1:] signs_2 = sa_quantiles < torch.cat( [sa_quantiles[:, 1:], sa_quantile_hats[:, -1:]], dim=1) gradient_of_taus = (torch.where(signs_1, values_1, -values_1) + torch.where(signs_2, values_2, -values_2)) fraction_loss = (gradient_of_taus * taus[:, 1:-1]).sum(1).mean() # calculate entropy loss entropy_loss = out.fractions.entropies.mean() fraction_entropy_loss = fraction_loss - self._ent_coef * entropy_loss self._fraction_optim.zero_grad() fraction_entropy_loss.backward(retain_graph=True) self._fraction_optim.step() self.optim.zero_grad() quantile_loss.backward() self.optim.step() self._iter += 1 return { "loss": quantile_loss.item() + fraction_entropy_loss.item(), "loss/quantile": quantile_loss.item(), "loss/fraction": fraction_loss.item(), "loss/entropy": entropy_loss.item() }
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    # critic
    td, critic_loss = self._mse_optimizer(batch, self.critic, self.critic_optim)
    batch.weight = td  # prio-buffer
    # actor
    actor_loss = -self.critic(batch.obs, self(batch).act).mean()
    self.actor_optim.zero_grad()
    actor_loss.backward()
    self.actor_optim.step()
    self.sync_weight()
    return {
        "loss/actor": actor_loss.item(),
        "loss/critic": critic_loss.item(),
    }
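# A sketch of the sync_weight() call that closes each actor-critic update
# above, assuming Polyak (soft) target updates with coefficient self._tau and
# target copies named actor_old / critic_old (both names are assumptions; the
# DQN-style variants instead hard-copy the online weights every _freq steps):
def sync_weight(self) -> None:
    """Soft-update the target networks towards the online networks."""
    for tgt, src in zip(self.actor_old.parameters(), self.actor.parameters()):
        tgt.data.copy_(tgt.data * (1.0 - self._tau) + src.data * self._tau)
    for tgt, src in zip(self.critic_old.parameters(),
                        self.critic.parameters()):
        tgt.data.copy_(tgt.data * (1.0 - self._tau) + src.data * self._tau)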
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
    if self._target and self._cnt % self._freq == 0:
        self.sync_weight()
    self.optim.zero_grad()
    q = self(batch, eps=0.).logits
    q = q[np.arange(len(q)), batch.act]
    r = to_torch_as(batch.returns, q).flatten()
    td = r - q
    loss = (td.pow(2) * batch.weight).mean()
    batch.weight = td  # prio-buffer
    loss.backward()
    self.optim.step()
    self._cnt += 1
    return {'loss': loss.item()}
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
    weight = batch.pop('weight', 1.)
    # critic 1
    current_q1 = self.critic1(batch.obs, batch.act).flatten()
    target_q = batch.returns.flatten()
    td1 = current_q1 - target_q
    critic1_loss = (td1.pow(2) * weight).mean()
    # critic1_loss = F.mse_loss(current_q1, target_q)
    self.critic1_optim.zero_grad()
    critic1_loss.backward()
    self.critic1_optim.step()
    # critic 2
    current_q2 = self.critic2(batch.obs, batch.act).flatten()
    td2 = current_q2 - target_q
    critic2_loss = (td2.pow(2) * weight).mean()
    # critic2_loss = F.mse_loss(current_q2, target_q)
    self.critic2_optim.zero_grad()
    critic2_loss.backward()
    self.critic2_optim.step()
    batch.weight = (td1 + td2) / 2.  # prio-buffer
    # actor
    obs_result = self(batch, explorating=False)
    a = obs_result.act
    current_q1a = self.critic1(batch.obs, a).flatten()
    current_q2a = self.critic2(batch.obs, a).flatten()
    actor_loss = (
        self._alpha * obs_result.log_prob.flatten() -
        torch.min(current_q1a, current_q2a)
    ).mean()
    self.actor_optim.zero_grad()
    actor_loss.backward()
    self.actor_optim.step()
    if self._automatic_alpha_tuning:
        log_prob = (obs_result.log_prob + self._target_entropy).detach()
        alpha_loss = -(self._log_alpha * log_prob).mean()
        self._alpha_optim.zero_grad()
        alpha_loss.backward()
        self._alpha_optim.step()
        self._alpha = self._log_alpha.exp()
    self.sync_weight()
    result = {
        'loss/actor': actor_loss.item(),
        'loss/critic1': critic1_loss.item(),
        'loss/critic2': critic2_loss.item(),
    }
    if self._automatic_alpha_tuning:
        result['loss/alpha'] = alpha_loss.item()
    return result
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: if self._target and self._iter % self._freq == 0: self.sync_weight() self.optim.zero_grad() weight = batch.pop("weight", 1.0) q = self(batch).logits q = q[np.arange(len(q)), batch.act] r = to_torch_as(batch.returns.flatten(), q) td = r - q loss = (td.pow(2) * weight).mean() batch.weight = td # prio-buffer loss.backward() self.optim.step() self._iter += 1 return {"loss": loss.item()}
def process_fn(
    self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> Batch:
    batch = super().process_fn(batch, buffer, indice)
    with torch.no_grad():
        # the degree of on-policyness
        opd = self(batch, output_dqn=False, output_opd=True).logits
        opd = opd[np.arange(len(opd)), batch.act]
        opd_weights = torch.sigmoid(opd * self.opd_temperature)
        opd_weights = opd_weights / torch.sum(opd_weights) * len(opd)
        opd_weights = opd_weights.detach().cpu().numpy()
    ori_weight = batch.pop("weight", 1.0)
    ori_weight *= opd_weights
    batch.weight = ori_weight
    return batch
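# For orientation, a simplified sketch of the surrounding update flow
# (assuming a Tianshou-style BasePolicy.update): process_fn() massages the
# sampled batch before learn(), and post_process_fn() writes the refreshed
# priorities back afterwards.
def update(self, sample_size: int, buffer: ReplayBuffer, **kwargs: Any):
    batch, indice = buffer.sample(sample_size)
    batch = self.process_fn(batch, buffer, indice)
    result = self.learn(batch, **kwargs)
    self.post_process_fn(batch, buffer, indice)
    return result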
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    # critic 1&2
    td1, critic1_loss = self._mse_optimizer(
        batch, self.critic1, self.critic1_optim, self.scaler, self.use_mixed
    )
    td2, critic2_loss = self._mse_optimizer(
        batch, self.critic2, self.critic2_optim, self.scaler, self.use_mixed
    )
    batch.weight = (td1 + td2) / 2.0  # prio-buffer
    with autocast(enabled=self.use_mixed):
        # actor
        obs_result = self(batch)
        a = obs_result.act
        current_q1a = self.critic1(batch.obs, a).flatten()
        current_q2a = self.critic2(batch.obs, a).flatten()
        actor_loss = (
            self._alpha * obs_result.log_prob.flatten() -
            torch.min(current_q1a, current_q2a)
        ).mean()
    self.actor_optim.zero_grad()
    self.scaler.scale(actor_loss).backward()
    self.scaler.step(self.actor_optim)
    self.scaler.update()
    # actor_loss.backward()
    # self.actor_optim.step()
    if self._is_auto_alpha:
        log_prob = obs_result.log_prob.detach() + self._target_entropy
        alpha_loss = -(self._log_alpha * log_prob).mean()
        self._alpha_optim.zero_grad()
        alpha_loss.backward()
        self._alpha_optim.step()
        self._alpha = self._log_alpha.detach().exp()
    self.sync_weight()
    result = {
        "loss/actor": actor_loss.item(),
        "loss/critic1": critic1_loss.item(),
        "loss/critic2": critic2_loss.item(),
    }
    if self._is_auto_alpha:
        result["loss/alpha"] = alpha_loss.item()
        result["alpha"] = self._alpha.item()  # type: ignore
    return result
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    if self._target and self._iter % self._freq == 0:
        self.sync_weight()
    self.optim.zero_grad()
    with torch.no_grad():
        target_dist = self._target_dist(batch)
    weight = batch.pop("weight", 1.0)
    curr_dist = self(batch).logits
    act = batch.act
    curr_dist = curr_dist[np.arange(len(act)), act, :]
    cross_entropy = -(target_dist * torch.log(curr_dist + 1e-8)).sum(1)
    loss = (cross_entropy * weight).mean()
    # ref: https://github.com/Kaixhin/Rainbow/blob/master/agent.py L94-100
    batch.weight = cross_entropy.detach()  # prio-buffer
    loss.backward()
    self.optim.step()
    self._iter += 1
    return {"loss": loss.item()}
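# The categorical target consumed above comes from projecting the clamped
# returns onto the fixed support. A sketch consistent with the cross-entropy
# computation, assuming the atoms live in self.support with uniform spacing
# self.delta_z (attribute names are assumptions):
def _target_dist(self, batch: Batch) -> torch.Tensor:
    """Project the n-step returns onto the fixed categorical support."""
    if self._target:
        act = self(batch, input="obs_next").act
        next_dist = self(batch, model="model_old", input="obs_next").logits
    else:
        next_batch = self(batch, input="obs_next")
        act = next_batch.act
        next_dist = next_batch.logits
    next_dist = next_dist[np.arange(len(act)), act, :]
    target_support = batch.returns.clamp(self._v_min, self._v_max)
    # linear projection of each return atom onto its two nearest support atoms
    # ref: https://github.com/ShangtongZhang/DeepRL
    target_dist = (
        1 - (target_support.unsqueeze(1) - self.support.view(1, -1, 1)).abs() /
        self.delta_z
    ).clamp(0, 1) * next_dist.unsqueeze(1)
    return target_dist.sum(-1)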
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: if self._target and self._cnt % self._freq == 0: self.sync_weight() self.optim.zero_grad() weight = batch.pop("weight", 1.0) q = self(batch).logits q = q[np.arange(len(q)), batch.act] r = to_torch_as(batch.returns.flatten(), q) td = r - q loss = (td.pow(2) * weight).mean() batch.weight = td # prio-buffer loss.backward() # Gradient clips torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1) self.optim.step() self._cnt += 1 for param_group in self.optim.param_groups: lr = param_group['lr'] return {"loss": loss.item(), "lr": lr}
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
    if self._target and self._cnt % self._freq == 0:
        self.sync_weight()
    weight = batch.pop("weight", 1.0)
    self.optim.zero_grad()
    q = self(batch, eps=0.).logits
    q = q[np.arange(len(q)), batch.act]
    r = to_torch_as(batch.returns, q).flatten()
    c = torch.nn.SmoothL1Loss(reduction='none')
    # c = lambda r, q: (r - q).pow(2)
    td = c(r, q)
    loss = (td * weight).mean()
    # prio-buffer: store the per-sample error, not the scalar mean loss
    batch.weight = td
    loss.backward()
    if self.grad_norm_clipping:
        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.grad_norm_clipping
        )
    self.optim.step()
    self._cnt += 1
    return {'loss': loss.item()}
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
    current_q = self.critic(batch.obs, batch.act).flatten()
    target_q = batch.returns.flatten()
    td = current_q - target_q
    critic_loss = (td.pow(2) * batch.weight).mean()
    batch.weight = td  # prio-buffer
    self.critic_optim.zero_grad()
    critic_loss.backward()
    self.critic_optim.step()
    action = self(batch, explorating=False).act
    actor_loss = -self.critic(batch.obs, action).mean()
    self.actor_optim.zero_grad()
    actor_loss.backward()
    self.actor_optim.step()
    self.sync_weight()
    return {
        'loss/actor': actor_loss.item(),
        'loss/critic': critic_loss.item(),
    }
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: if self._target and self._iter % self._freq == 0: self.sync_weight() self.optim.zero_grad() weight = batch.pop("weight", 1.0) act = to_torch(batch.act, dtype=torch.long, device=batch.returns.device) q = self(batch).logits act_mask = torch.zeros_like(q) act_mask = act_mask.scatter_(-1, act.unsqueeze(-1), 1) act_q = q * act_mask returns = batch.returns returns = returns * act_mask td_error = returns - act_q loss = (td_error.pow(2).sum(-1).mean(-1) * weight).mean() batch.weight = td_error.sum(-1).sum(-1) # prio-buffer loss.backward() self.optim.step() self._iter += 1 return {"loss": loss.item()}
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    # critic ensemble
    weight = getattr(batch, "weight", 1.0)
    current_qs = self.critics(batch.obs, batch.act).flatten(1)
    target_q = batch.returns.flatten()
    td = current_qs - target_q
    critic_loss = (td.pow(2) * weight).mean()
    self.critics_optim.zero_grad()
    critic_loss.backward()
    self.critics_optim.step()
    batch.weight = torch.mean(td, dim=0)  # prio-buffer
    self.critic_gradient_step += 1
    # actor
    if self.critic_gradient_step % self.actor_delay == 0:
        obs_result = self(batch)
        a = obs_result.act
        current_qa = self.critics(batch.obs, a).mean(dim=0).flatten()
        actor_loss = (
            self._alpha * obs_result.log_prob.flatten() - current_qa
        ).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()
        if self._is_auto_alpha:
            log_prob = obs_result.log_prob.detach() + self._target_entropy
            alpha_loss = -(self._log_alpha * log_prob).mean()
            self._alpha_optim.zero_grad()
            alpha_loss.backward()
            self._alpha_optim.step()
            self._alpha = self._log_alpha.detach().exp()
    self.sync_weight()
    result = {"loss/critics": critic_loss.item()}
    if self.critic_gradient_step % self.actor_delay == 0:
        result["loss/actor"] = actor_loss.item()
        if self._is_auto_alpha:
            result["loss/alpha"] = alpha_loss.item()
            result["alpha"] = self._alpha.item()  # type: ignore
    return result
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: weight = batch.pop("weight", 1.0) current_q = self.critic(batch.obs, batch.act).flatten() target_q = batch.returns.flatten() td = current_q - target_q critic_loss = (td.pow(2) * weight).mean() batch.weight = td # prio-buffer self.critic_optim.zero_grad() critic_loss.backward() self.critic_optim.step() action = self(batch).act actor_loss = -self.critic(batch.obs, action).mean() self.actor_optim.zero_grad() actor_loss.backward() self.actor_optim.step() self.sync_weight() return { "loss/actor": actor_loss.item(), "loss/critic": critic_loss.item(), }
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    # critic 1&2
    td1, critic1_loss = self._mse_optimizer(
        batch, self.critic1, self.critic1_optim
    )
    td2, critic2_loss = self._mse_optimizer(
        batch, self.critic2, self.critic2_optim
    )
    batch.weight = (td1 + td2) / 2.0  # prio-buffer
    # actor
    obs_result = self(batch)
    act = obs_result.act
    current_q1a = self.critic1(batch.obs, act).flatten()
    current_q2a = self.critic2(batch.obs, act).flatten()
    actor_loss = (
        self._alpha * obs_result.log_prob.flatten() -
        torch.min(current_q1a, current_q2a)
    ).mean()
    self.actor_optim.zero_grad()
    actor_loss.backward()
    self.actor_optim.step()
    if self._is_auto_alpha:
        log_prob = obs_result.log_prob.detach() + self._target_entropy
        # please take a look at issue #258 if you'd like to change this line
        alpha_loss = -(self._log_alpha * log_prob).mean()
        self._alpha_optim.zero_grad()
        alpha_loss.backward()
        self._alpha_optim.step()
        self._alpha = self._log_alpha.detach().exp()
    self.sync_weight()
    result = {
        "loss/actor": actor_loss.item(),
        "loss/critic1": critic1_loss.item(),
        "loss/critic2": critic2_loss.item(),
    }
    if self._is_auto_alpha:
        result["loss/alpha"] = alpha_loss.item()
        result["alpha"] = self._alpha.item()  # type: ignore
    return result
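# A minimal setup sketch for the automatic entropy tuning used above
# (hypothetical local names, shown outside the policy class): alpha is
# optimized through its log so it stays positive, and the target entropy
# conventionally defaults to -dim(A) for a continuous action space.
act_dim = 6  # e.g. np.prod(env.action_space.shape)
target_entropy = -float(act_dim)
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
# Assuming the Tianshou-style tuple interface, (target_entropy, log_alpha,
# alpha_optim) is handed to the policy, which then sets _is_auto_alpha = True.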
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: if self._target and self._iter % self._freq == 0: self.sync_weight() self.optim.zero_grad() weight = batch.pop("weight", 1.0) q = self(batch).logits q = q[np.arange(len(q)), batch.act] returns = to_torch_as(batch.returns.flatten(), q) td_error = returns - q if self._clip_loss_grad: y = q.reshape(-1, 1) t = returns.reshape(-1, 1) loss = torch.nn.functional.huber_loss(y, t, reduction="mean") else: loss = (td_error.pow(2) * weight).mean() batch.weight = td_error # prio-buffer loss.backward() self.optim.step() self._iter += 1 return {"loss": loss.item()}
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    # critic 1&2
    td1, critic1_loss = self._mse_optimizer(
        batch, self.critic1, self.critic1_optim)
    td2, critic2_loss = self._mse_optimizer(
        batch, self.critic2, self.critic2_optim)
    batch.weight = (td1 + td2) / 2.0  # prio-buffer
    # actor
    if self._cnt % self._freq == 0:
        actor_loss = -self.critic1(batch.obs, self(batch, eps=0.0).act).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self._last = actor_loss.item()
        self.actor_optim.step()
        self.sync_weight()
    self._cnt += 1
    return {
        "loss/actor": self._last,
        "loss/critic1": critic1_loss.item(),
        "loss/critic2": critic2_loss.item(),
    }
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    if self._target and self._iter % self._freq == 0:
        self.sync_weight()
    self.optim.zero_grad()
    # compute dqn loss
    weight = batch.pop("weight", 1.0)
    q = self(batch).logits
    q = q[np.arange(len(q)), batch.act]
    r = to_torch_as(batch.returns.flatten(), q)
    weight = to_torch_as(weight, q)
    td = r - q
    dqn_loss = (td.pow(2) * weight).mean()
    batch.weight = td  # prio-buffer
    # compute opd loss
    slow_preds = self(batch, output_dqn=False, output_opd=True).logits
    fast_preds = self(
        batch, input="fast_obs", output_dqn=False, output_opd=True
    ).logits
    # act_dim + 1 classes; the last class indicates the state is off-policy
    slow_label = torch.ones_like(
        slow_preds[:, 0], dtype=torch.int64
    ) * int(self.model.output_dim)
    fast_label = to_torch_as(torch.tensor(batch.fast_act), slow_label)
    opd_loss = F.cross_entropy(slow_preds, slow_label) + \
        F.cross_entropy(fast_preds, fast_label)
    # compute loss and back prop
    loss = dqn_loss + self.opd_loss_coeff * opd_loss
    loss.backward()
    self.optim.step()
    self._iter += 1
    return {
        "loss": loss.item(),
        "dqn_loss": dqn_loss.item(),
        "opd_loss": opd_loss.item(),
    }
def _compute_return(
    self,
    batch: Batch,
    buffer: ReplayBuffer,
    indice: np.ndarray,
    gamma: float = 0.99,
) -> Batch:
    rew = batch.rew
    with torch.no_grad():
        target_q_torch = self._target_q(buffer, indice)  # (bsz, ?)
    target_q = to_numpy(target_q_torch)
    end_flag = buffer.done.copy()
    end_flag[buffer.unfinished_index()] = True
    end_flag = end_flag[indice]
    mean_target_q = (
        np.mean(target_q, -1) if len(target_q.shape) > 1 else target_q
    )
    _target_q = rew + gamma * mean_target_q * (1 - end_flag)
    target_q = np.repeat(_target_q[..., None], self.num_branches, axis=-1)
    target_q = np.repeat(target_q[..., None], self.max_action_num, axis=-1)
    batch.returns = to_torch_as(target_q, target_q_torch)
    if hasattr(batch, "weight"):  # prio buffer update
        batch.weight = to_torch_as(batch.weight, target_q_torch)
    return batch