def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int,
        **kwargs: Any) -> Dict[str, List[float]]:
    """Natural Policy Gradient update.

    For each minibatch: (1) compute the vanilla policy-gradient direction,
    (2) precondition it with the inverse Fisher matrix via conjugate
    gradients, (3) take one fixed-size natural-gradient step, then
    (4) regress the critic for ``self._optim_critic_iters`` steps.
    Returns per-minibatch actor losses, critic losses and KL values.
    """
    actor_losses, vf_losses, kls = [], [], []
    for step in range(repeat):
        for b in batch.split(batch_size, merge_last=True):
            # optimize actor
            # direction: calculate vanilla gradient
            dist = self(b).dist  # TODO could come from batch
            # importance ratio pi_new/pi_old; logp_old was set in process_fn
            ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
            ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
            actor_loss = -(ratio * b.adv).mean()
            flat_grads = self._get_flat_grad(
                actor_loss, self.actor, retain_graph=True).detach()
            # direction: calculate natural gradient
            with torch.no_grad():
                old_dist = self(b).dist
            kl = kl_divergence(old_dist, dist).mean()
            # calculate first order gradient of kl with respect to theta
            flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True)
            # CG solves F^-1 g without materializing the Fisher matrix
            search_direction = -self._conjugate_gradients(
                flat_grads, flat_kl_grad, nsteps=10)
            # step: move all actor parameters along the natural gradient
            with torch.no_grad():
                flat_params = torch.cat([
                    param.data.view(-1) for param in self.actor.parameters()
                ])
                new_flat_params = flat_params + self._step_size * search_direction
                self._set_from_flat_params(self.actor, new_flat_params)
                new_dist = self(b).dist
                # report the KL actually induced by the step
                kl = kl_divergence(old_dist, new_dist).mean()
            # optimize critic
            for _ in range(self._optim_critic_iters):
                value = self.critic(b.obs).flatten()
                vf_loss = F.mse_loss(b.returns, value)
                self.optim.zero_grad()
                vf_loss.backward()
                self.optim.step()
            actor_losses.append(actor_loss.item())
            vf_losses.append(vf_loss.item())
            kls.append(kl.item())
    # update learning rate if lr_scheduler is given
    if self.lr_scheduler is not None:
        self.lr_scheduler.step()
    return {
        "loss/actor": actor_losses,
        "loss/vf": vf_losses,
        "kl": kls,
    }
def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int,
        **kwargs: Any) -> Dict[str, List[float]]:
    """A2C update: joint actor/critic step on each minibatch.

    loss = policy_loss + w_vf * value_loss - w_ent * entropy_bonus.
    Actor and critic gradients are clipped together when
    ``self._grad_norm`` is set.
    """
    losses, actor_losses, vf_losses, ent_losses = [], [], [], []
    for _ in range(repeat):
        for b in batch.split(batch_size, merge_last=True):
            self.optim.zero_grad()
            dist = self(b).dist
            v = self.critic(b.obs).flatten()
            a = to_torch_as(b.act, v)
            r = to_torch_as(b.returns, v)
            # reshape keeps a leading event dim for multi-dimensional actions
            log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
            # (r - v) is detached: the critic is trained only via vf_loss
            a_loss = -(log_prob * (r - v).detach()).mean()
            vf_loss = F.mse_loss(r, v)  # type: ignore
            ent_loss = dist.entropy().mean()
            loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
            loss.backward()
            if self._grad_norm is not None:
                nn.utils.clip_grad_norm_(
                    list(self.actor.parameters()) +
                    list(self.critic.parameters()),
                    max_norm=self._grad_norm,
                )
            self.optim.step()
            actor_losses.append(a_loss.item())
            vf_losses.append(vf_loss.item())
            ent_losses.append(ent_loss.item())
            losses.append(loss.item())
    return {
        "loss": losses,
        "loss/actor": actor_losses,
        "loss/vf": vf_losses,
        "loss/ent": ent_losses,
    }
def learn(self, batch: Batch, batch_size: int, repeat: int,
          **kwargs) -> Dict[str, List[float]]:
    """Legacy A2C update with in-place batch-level return normalization."""
    # remember minibatch size for process_fn's critic-evaluation splits
    self._batch = batch_size
    r = batch.returns
    # standardize returns once up front when enabled and non-degenerate
    if self._rew_norm and r.std() > self.__eps:
        batch.returns = (r - r.mean()) / r.std()
    losses, actor_losses, vf_losses, ent_losses = [], [], [], []
    for _ in range(repeat):
        for b in batch.split(batch_size):
            self.optim.zero_grad()
            dist = self(b).dist
            v = self.critic(b.obs)
            a = torch.tensor(b.act, device=v.device)
            r = torch.tensor(b.returns, device=v.device)
            # NOTE(review): if v has shape (N, 1) and r has shape (N,),
            # (r - v) broadcasts to (N, N) — confirm the critic's output
            # shape before relying on this loss.
            a_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
            # r[:, None] aligns the target with a column-vector critic output
            vf_loss = F.mse_loss(r[:, None], v)
            ent_loss = dist.entropy().mean()
            loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
            loss.backward()
            if self._grad_norm:
                nn.utils.clip_grad_norm_(
                    list(self.actor.parameters()) +
                    list(self.critic.parameters()),
                    max_norm=self._grad_norm)
            self.optim.step()
            actor_losses.append(a_loss.item())
            vf_losses.append(vf_loss.item())
            ent_losses.append(ent_loss.item())
            losses.append(loss.item())
    return {
        'loss': losses,
        'loss/actor': actor_losses,
        'loss/vf': vf_losses,
        'loss/ent': ent_losses,
    }
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
               indice: np.ndarray) -> Batch:
    """Legacy PPO preprocessing.

    Optionally standardizes rewards, computes episodic returns, and adds
    ``v`` (old state values), ``logp_old``, ``returns`` and ``adv`` to the
    batch; advantages are standardized as well when reward norm is on.
    """
    if self._rew_norm:
        mean, std = batch.rew.mean(), batch.rew.std()
        # NOTE(review): the third positional argument of np.isclose is
        # rtol, not atol — confirm the intended tolerance semantics here.
        if not np.isclose(std, 0, 1e-2):
            batch.rew = (batch.rew - mean) / std
    v, v_, old_log_prob = [], [], []
    with torch.no_grad():
        # evaluate critic and old policy minibatch-wise to bound memory
        for b in batch.split(self._batch, shuffle=False):
            v_.append(self.critic(b.obs_next))
            v.append(self.critic(b.obs))
            old_log_prob.append(self(b).dist.log_prob(
                to_torch_as(b.act, v[0])))
    v_ = to_numpy(torch.cat(v_, dim=0))
    batch = self.compute_episodic_return(
        batch, v_, gamma=self._gamma, gae_lambda=self._lambda,
        rew_norm=self._rew_norm)
    batch.v = torch.cat(v, dim=0).flatten()  # old value
    batch.act = to_torch_as(batch.act, v[0])
    batch.logp_old = torch.cat(old_log_prob, dim=0)
    batch.returns = to_torch_as(batch.returns, v[0])
    batch.adv = batch.returns - batch.v
    if self._rew_norm:
        mean, std = batch.adv.mean(), batch.adv.std()
        if not np.isclose(std.item(), 0, 1e-2):
            batch.adv = (batch.adv - mean) / std
    return batch
def _compute_returns(
    self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> Batch:
    """Compute GAE returns/advantages with optional running-return norm.

    Evaluates V(s) and V(s') without gradients, unnormalizes the bootstrap
    values when value normalization is active, computes episodic returns
    and advantages, then stores normalized returns back on the batch and
    updates the running statistics.
    """
    v_s, v_s_ = [], []
    with torch.no_grad():
        for b in batch.split(self._batch, shuffle=False, merge_last=True):
            v_s.append(self.critic(b.obs))
            v_s_.append(self.critic(b.obs_next))
    batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
    v_s = batch.v_s.cpu().numpy()
    v_s_ = torch.cat(v_s_, dim=0).flatten().cpu().numpy()
    # when normalizing values, we do not minus self.ret_rms.mean to be
    # numerically consistent with OpenAI baselines' value normalization
    # pipeline. Empirical study also shows that "minus mean" will harm
    # performances a tiny little bit due to unknown reasons (on Mujoco
    # envs, not confident, though).
    if self._rew_norm:  # unnormalize v_s & v_s_
        v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
        v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
    unnormalized_returns, advantages = self.compute_episodic_return(
        batch, buffer, indice, v_s_, v_s,
        gamma=self._gamma, gae_lambda=self._lambda)
    if self._rew_norm:
        batch.returns = unnormalized_returns / \
            np.sqrt(self.ret_rms.var + self._eps)
        self.ret_rms.update(unnormalized_returns)
    else:
        batch.returns = unnormalized_returns
    batch.returns = to_torch_as(batch.returns, batch.v_s)
    batch.adv = to_torch_as(advantages, batch.v_s)
    return batch
def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int,
        **kwargs: Any) -> Dict[str, List[float]]:
    """Adversarial imitation (GAIL-style) update.

    First fits the discriminator to separate policy transitions (pushed
    toward negative logits) from expert transitions (positive logits),
    then delegates the policy update to the parent on-policy algorithm.
    """
    disc_losses: List[float] = []
    policy_accs: List[float] = []
    expert_accs: List[float] = []
    sample_size = len(batch) // self.disc_update_num
    for slice_ in batch.split(sample_size, merge_last=True):
        policy_logits = self.disc(slice_)
        expert_batch = self.expert_buffer.sample(sample_size)[0]
        expert_logits = self.disc(expert_batch)
        # maximize log sigma(expert_logits) + log sigma(-policy_logits)
        disc_loss = -F.logsigmoid(-policy_logits).mean() \
            - F.logsigmoid(expert_logits).mean()
        self.disc_optim.zero_grad()
        disc_loss.backward()
        self.disc_optim.step()
        disc_losses.append(disc_loss.item())
        # "accuracy": fraction of samples the discriminator labels correctly
        policy_accs.append((policy_logits < 0).float().mean().item())
        expert_accs.append((expert_logits > 0).float().mean().item())
    # update policy via the parent algorithm on the same batch
    res = super().learn(batch, batch_size, repeat, **kwargs)
    res["loss/disc"] = disc_losses
    res["stats/acc_pi"] = policy_accs
    res["stats/acc_exp"] = expert_accs
    return res
def process_fn(
    self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> Batch:
    """Compute GAE targets plus old values/log-probs for the surrogate loss.

    Populates ``batch.v_s``, ``batch.returns``, ``batch.adv`` and
    ``batch.logp_old``, and casts ``batch.act`` to the value tensor's
    dtype/device.
    """
    v_s, v_s_, old_log_prob = [], [], []
    with torch.no_grad():
        # minibatch-wise evaluation bounds peak memory
        for b in batch.split(self._batch, shuffle=False, merge_last=True):
            v_s.append(self.critic(b.obs))
            v_s_.append(self.critic(b.obs_next))
            old_log_prob.append(self(b).dist.log_prob(to_torch_as(b.act, v_s[0])))
    batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
    v_s = to_numpy(batch.v_s)
    v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
    # running return normalization: map values back to the unnormalized
    # scale before bootstrapping
    if self._rew_norm:  # unnormalize v_s & v_s_
        v_s = v_s * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
        v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
    unnormalized_returns, advantages = self.compute_episodic_return(
        batch, buffer, indice, v_s_, v_s,
        gamma=self._gamma, gae_lambda=self._lambda)
    if self._rew_norm:
        batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
            np.sqrt(self.ret_rms.var + self._eps)
        self.ret_rms.update(unnormalized_returns)
        mean, std = np.mean(advantages), np.std(advantages)
        advantages = (advantages - mean) / std  # per-batch norm
    else:
        batch.returns = unnormalized_returns
    batch.act = to_torch_as(batch.act, batch.v_s)
    batch.logp_old = torch.cat(old_log_prob, dim=0)
    batch.returns = to_torch_as(batch.returns, batch.v_s)
    batch.adv = to_torch_as(advantages, batch.v_s)
    return batch
def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int,
        **kwargs: Any) -> Dict[str, List[float]]:
    """PPO update with clipped surrogate objective.

    Supports optional advantage recomputation per epoch, per-minibatch
    advantage normalization, dual clipping, value clipping and joint
    gradient clipping.
    """
    losses, clip_losses, vf_losses, ent_losses = [], [], [], []
    for step in range(repeat):
        # re-estimate advantages with the updated critic each epoch
        if self._recompute_adv and step > 0:
            batch = self._compute_returns(batch, self._buffer, self._indice)
        for b in batch.split(batch_size, merge_last=True):
            # calculate loss for actor
            dist = self(b).dist
            if self._norm_adv:
                mean, std = b.adv.mean(), b.adv.std()
                b.adv = (b.adv - mean) / std  # per-batch norm
            # importance ratio pi_new/pi_old from stored old log-probs
            ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
            ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
            surr1 = ratio * b.adv
            surr2 = ratio.clamp(1.0 - self._eps_clip,
                                1.0 + self._eps_clip) * b.adv
            if self._dual_clip:
                # lower-bound the clipped objective for negative advantages
                clip_loss = -torch.max(torch.min(surr1, surr2),
                                       self._dual_clip * b.adv).mean()
            else:
                clip_loss = -torch.min(surr1, surr2).mean()
            # calculate loss for critic
            value = self.critic(b.obs).flatten()
            if self._value_clip:
                # clip value update around the old value estimate
                v_clip = b.v_s + (value - b.v_s).clamp(
                    -self._eps_clip, self._eps_clip)
                vf1 = (b.returns - value).pow(2)
                vf2 = (b.returns - v_clip).pow(2)
                vf_loss = torch.max(vf1, vf2).mean()
            else:
                vf_loss = (b.returns - value).pow(2).mean()
            # calculate regularization and overall loss
            ent_loss = dist.entropy().mean()
            loss = clip_loss + self._weight_vf * vf_loss \
                - self._weight_ent * ent_loss
            self.optim.zero_grad()
            loss.backward()
            if self._grad_norm:  # clip large gradient
                nn.utils.clip_grad_norm_(
                    list(self.actor.parameters()) +
                    list(self.critic.parameters()),
                    max_norm=self._grad_norm)
            self.optim.step()
            clip_losses.append(clip_loss.item())
            vf_losses.append(vf_loss.item())
            ent_losses.append(ent_loss.item())
            losses.append(loss.item())
    # update learning rate if lr_scheduler is given
    if self.lr_scheduler is not None:
        self.lr_scheduler.step()
    return {
        "loss": losses,
        "loss/clip": clip_losses,
        "loss/vf": vf_losses,
        "loss/ent": ent_losses,
    }
def test_batch():
    """Smoke-test Batch basics: getattr/setattr, cat_, split, scalar
    construction, indexing/slicing and elementwise arithmetic."""
    batch = Batch(obs=[0], np=np.zeros([3, 4]))
    assert batch.obs == batch["obs"]
    batch.obs = [1]
    assert batch.obs == [1]
    batch.cat_(batch)
    assert batch.obs == [1, 1]
    assert batch.np.shape == (6, 4)
    assert batch[0].obs == batch[1].obs
    batch.obs = np.arange(5)
    for i, b in enumerate(batch.split(1, shuffle=False)):
        if i != 5:
            assert b.obs == batch[i].obs
        else:
            # NOTE(review): split(1) over 5 obs yields i in 0..4, so this
            # branch looks unreachable — confirm the intended batch length.
            with pytest.raises(AttributeError):
                batch[i].obs
            with pytest.raises(AttributeError):
                b.obs
    print(batch)
    # nesting a plain dict: leaf types must be preserved on item access
    batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])}
    batch_item = Batch({'a': [batch_dict]})[0]
    assert isinstance(batch_item.a.b, np.ndarray)
    assert batch_item.a.b == batch_dict['b']
    assert isinstance(batch_item.a.c, float)
    assert batch_item.a.c == batch_dict['c']
    assert isinstance(batch_item.a.d, torch.Tensor)
    assert batch_item.a.d == batch_dict['d']
    batch2 = Batch(a=[{
        'b': np.float64(1.0),
        'c': np.zeros(1),
        'd': Batch(e=np.array(3.0))}])
    assert len(batch2) == 1
    # out-of-range and scalar indexing must raise
    with pytest.raises(IndexError):
        batch2[-2]
    with pytest.raises(IndexError):
        batch2[1]
    with pytest.raises(TypeError):
        batch2[0][0]
    assert isinstance(batch2[0].a.c, np.ndarray)
    assert isinstance(batch2[0].a.b, np.float64)
    assert isinstance(batch2[0].a.d.e, np.float64)
    # round-trip through list / comprehension reconstruction
    batch2_from_list = Batch(list(batch2))
    batch2_from_comp = Batch([e for e in batch2])
    assert batch2_from_list.a.b == batch2.a.b
    assert batch2_from_list.a.c == batch2.a.c
    assert batch2_from_list.a.d.e == batch2.a.d.e
    assert batch2_from_comp.a.b == batch2.a.b
    assert batch2_from_comp.a.c == batch2.a.c
    assert batch2_from_comp.a.d.e == batch2.a.d.e
    # equivalent slicing forms must agree
    for batch_slice in [batch2[slice(0, 1)], batch2[:1], batch2[0:]]:
        assert batch_slice.a.b == batch2.a.b
        assert batch_slice.a.c == batch2.a.c
        assert batch_slice.a.d.e == batch2.a.d.e
    # arithmetic broadcasts into nested leaves
    batch2_sum = (batch2 + 1.0) * 2
    assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
    assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
    assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
               indices: np.ndarray) -> Batch:
    """Delegate return/advantage computation to the parent policy, then
    attach old-policy log-probabilities (and optionally standardize the
    advantages) for the PPO surrogate objective."""
    batch = super().process_fn(batch, buffer, indices)
    with torch.no_grad():
        logp_chunks = [
            self(part).dist.log_prob(part.act)
            for part in batch.split(self._batch, shuffle=False, merge_last=True)
        ]
    batch.logp_old = torch.cat(logp_chunks, dim=0)
    if self._norm_adv:
        adv = batch.adv
        batch.adv = (adv - adv.mean()) / adv.std()
    return batch
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
               indice: np.ndarray) -> Batch:
    """PPO preprocessing: GAE returns/advantages plus old log-probs."""
    if self._recompute_adv:
        # buffer input `buffer` and `indice` to be used in `learn()`.
        self._buffer, self._indice = buffer, indice
    batch = self._compute_returns(batch, buffer, indice)
    batch.act = to_torch_as(batch.act, batch.v_s)
    old_log_prob = []
    with torch.no_grad():
        # minibatch-wise evaluation of the (still-old) policy
        for b in batch.split(self._batch, shuffle=False, merge_last=True):
            old_log_prob.append(self(b).dist.log_prob(b.act))
    batch.logp_old = torch.cat(old_log_prob, dim=0)
    return batch
def test_batch():
    """Basic Batch behavior: attribute assignment, append, len and split."""
    batch = Batch(obs=[0], np=np.zeros([3, 4]))
    batch.obs = [1]
    assert batch.obs == [1]
    batch.append(batch)
    assert batch.obs == [1, 1]
    assert batch.np.shape == (6, 4)
    assert batch[0].obs == batch[1].obs
    # indexing past the end must raise
    with pytest.raises(IndexError):
        batch[2]
    batch.obs = np.arange(5)
    for i, b in enumerate(batch.split(1, permute=False)):
        assert b.obs == batch[i].obs
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
               indice: np.ndarray) -> Batch:
    """Compute (generalized) advantage-based returns for the batch.

    For the degenerate lambda values 0 and 1 no critic bootstrap values
    are needed; otherwise V(s') is evaluated minibatch-wise without
    gradients and fed into the episodic-return computation.
    """
    if self._lambda in [0, 1]:
        return self.compute_episodic_return(
            batch, None, gamma=self._gamma, gae_lambda=self._lambda)
    with torch.no_grad():
        pieces = [
            to_numpy(self.critic(part.obs_next))
            for part in batch.split(self._batch, shuffle=False, merge_last=True)
        ]
    next_values = np.concatenate(pieces, axis=0)
    return self.compute_episodic_return(
        batch, next_values, gamma=self._gamma,
        gae_lambda=self._lambda, rew_norm=self._rew_norm)
def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int,
        **kwargs: Any) -> Dict[str, List[float]]:
    """A2C-style update driven by a KFAC (natural-gradient) optimizer.

    Fix: the advantage-normalization branch was dead code — a debugging
    leftover (``and False``) silently ignored ``self._norm_adv``; the
    guard now honors the flag again. Commented-out debug prints removed.

    Every ``self.optim.Ts`` steps the optimizer's Fisher-curvature
    statistics are refreshed (see Martens & Grosse, KFAC) before the
    regular loss step.
    """
    losses, actor_losses, vf_losses, ent_losses = [], [], [], []
    for _ in range(repeat):
        for b in batch.split(batch_size, merge_last=True):
            # calculate loss for actor
            dist = self(b).dist
            if self._norm_adv:
                mean, std = b.adv.mean(), b.adv.std()
                b.adv = (b.adv - mean) / std  # per-batch norm
            log_prob = dist.log_prob(b.act).reshape(len(b.adv), -1).transpose(0, 1)
            actor_loss = -(log_prob * b.adv).mean()
            # calculate loss for critic
            value = self.critic(b.obs).flatten()
            vf_loss = F.mse_loss(b.returns, value)
            # calculate regularization and overall loss
            ent_loss = dist.entropy().mean()
            loss = actor_loss + self._weight_vf * vf_loss \
                - self._weight_ent * ent_loss
            if self.optim.steps % self.optim.Ts == 0:
                # Compute fisher, see Martens 2014
                self.optim.model.zero_grad()
                pg_fisher_loss = -log_prob.mean()
                # sampled-value trick supplies curvature for the critic head
                value_noise = torch.randn(value.size(), device=value.device)
                sample_value = value + value_noise
                vf_fisher_loss = -(value - sample_value.detach()).pow(2).mean()
                fisher_loss = pg_fisher_loss + vf_fisher_loss
                self.optim.acc_stats = True
                fisher_loss.backward(retain_graph=True)
                self.optim.acc_stats = False
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()
            actor_losses.append(actor_loss.item())
            vf_losses.append(vf_loss.item())
            ent_losses.append(ent_loss.item())
            losses.append(loss.item())
    # update learning rate if lr_scheduler is given
    if self.lr_scheduler is not None:
        self.lr_scheduler.step()
    return {
        "loss": losses,
        "loss/actor": actor_losses,
        "loss/vf": vf_losses,
        "loss/ent": ent_losses,
    }
def learn(  # type: ignore
    self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]:
    """PPO update: clipped surrogate + 0.5 * value loss - entropy bonus.

    Supports dual clipping, value clipping and optional joint gradient
    clipping; steps the LR scheduler once per call if present.
    """
    losses, clip_losses, vf_losses, ent_losses = [], [], [], []
    for _ in range(repeat):
        for b in batch.split(batch_size, merge_last=True):
            dist = self(b).dist
            value = self.critic(b.obs).flatten()
            # importance ratio pi_new/pi_old from stored old log-probs
            ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
            ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
            surr1 = ratio * b.adv
            surr2 = ratio.clamp(1.0 - self._eps_clip,
                                1.0 + self._eps_clip) * b.adv
            if self._dual_clip:
                # lower-bound the objective for negative advantages
                clip_loss = -torch.max(
                    torch.min(surr1, surr2), self._dual_clip * b.adv
                ).mean()
            else:
                clip_loss = -torch.min(surr1, surr2).mean()
            clip_losses.append(clip_loss.item())
            if self._value_clip:
                # clip the value update around the old estimate b.v_s
                v_clip = b.v_s + (value - b.v_s).clamp(
                    -self._eps_clip, self._eps_clip)
                vf1 = (b.returns - value).pow(2)
                vf2 = (b.returns - v_clip).pow(2)
                vf_loss = 0.5 * torch.max(vf1, vf2).mean()
            else:
                vf_loss = 0.5 * (b.returns - value).pow(2).mean()
            vf_losses.append(vf_loss.item())
            e_loss = dist.entropy().mean()
            ent_losses.append(e_loss.item())
            loss = clip_loss + self._weight_vf * vf_loss \
                - self._weight_ent * e_loss
            losses.append(loss.item())
            self.optim.zero_grad()
            loss.backward()
            if self._grad_norm is not None:
                nn.utils.clip_grad_norm_(
                    list(self.actor.parameters()) +
                    list(self.critic.parameters()),
                    self._grad_norm)
            self.optim.step()
    # update learning rate if lr_scheduler is given
    if self.lr_scheduler is not None:
        self.lr_scheduler.step()
    return {
        "loss": losses,
        "loss/clip": clip_losses,
        "loss/vf": vf_losses,
        "loss/ent": ent_losses,
    }
def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int,
        **kwargs: Any) -> Dict[str, List[float]]:
    """A2C update with optional AMP mixed precision.

    Forward/loss computation runs under ``autocast`` when
    ``self.use_mixed`` is set; gradients flow through ``self.scaler`` and
    are unscaled before optional norm clipping.
    """
    losses, actor_losses, vf_losses, ent_losses = [], [], [], []
    for _ in range(repeat):
        for b in batch.split(batch_size, merge_last=True):
            self.optim.zero_grad()
            with autocast(enabled=self.use_mixed):
                # calculate loss for actor
                dist = self(b).dist
                log_prob = dist.log_prob(b.act).reshape(
                    len(b.adv), -1).transpose(0, 1)
                actor_loss = -(log_prob * b.adv).mean()
                # calculate loss for critic
                value = self.critic(b.obs).flatten()
                vf_loss = F.mse_loss(b.returns, value)
                # vf_loss = F.smooth_l1_loss(b.returns, value)  # Experiment
                # calculate regularization and overall loss
                ent_loss = dist.entropy().mean()
                loss = actor_loss + self._weight_vf * vf_loss \
                    - self._weight_ent * ent_loss
            self.scaler.scale(loss).backward()
            if self._grad_norm:  # clip large gradient
                # gradients must be unscaled before clipping by true norm
                self.scaler.unscale_(self.optim)
                nn.utils.clip_grad_norm_(
                    list(self.actor.parameters()) +
                    list(self.critic.parameters()),
                    max_norm=self._grad_norm)
            # scaler.step skips the update on inf/nan grads, then adapts scale
            self.scaler.step(self.optim)
            self.scaler.update()
            actor_losses.append(actor_loss.item())
            vf_losses.append(vf_loss.item())
            ent_losses.append(ent_loss.item())
            losses.append(loss.item())
    # update learning rate if lr_scheduler is given
    if self.lr_scheduler is not None:
        self.lr_scheduler.step()
    return {
        "loss": losses,
        "loss/actor": actor_losses,
        "loss/vf": vf_losses,
        "loss/ent": ent_losses,
    }
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
               indice: np.ndarray) -> Batch:
    """Legacy A2C preprocessing: optional reward standardization followed
    by (generalized) advantage return computation."""
    if self._rew_norm:
        rew_mean, rew_std = batch.rew.mean(), batch.rew.std()
        if rew_std > self.__eps:
            batch.rew = (batch.rew - rew_mean) / rew_std
    if self._lambda in [0, 1]:
        # degenerate lambda-returns need no bootstrap values
        return self.compute_episodic_return(
            batch, None, gamma=self._gamma, gae_lambda=self._lambda)
    with torch.no_grad():
        chunks = [
            self.critic(part.obs_next)
            for part in batch.split(self._batch, shuffle=False)
        ]
    next_values = torch.cat(chunks, dim=0).cpu().numpy()
    return self.compute_episodic_return(
        batch, next_values, gamma=self._gamma, gae_lambda=self._lambda)
def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int,
        **kwargs: Any) -> Dict[str, List[float]]:
    """Vanilla policy-gradient (REINFORCE) update.

    Minimizes -E[log pi(a|s) * G] on every minibatch, making ``repeat``
    passes over the batch; returns the per-minibatch loss history.
    """
    loss_history: List[float] = []
    for _ in range(repeat):
        for minibatch in batch.split(batch_size, merge_last=True):
            self.optim.zero_grad()
            dist = self(minibatch).dist
            act = to_torch_as(minibatch.act, dist.logits)
            ret = to_torch_as(minibatch.returns, dist.logits)
            # keep a leading event dim for multi-dimensional actions
            logp = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)
            pg_loss = -(logp * ret).mean()
            pg_loss.backward()
            self.optim.step()
            loss_history.append(pg_loss.item())
    return {"loss": loss_history}
def learn(self, batch: Batch, batch_size: int, repeat: int,
          **kwargs) -> Dict[str, List[float]]:
    """Legacy REINFORCE update.

    Standardizes the returns once up front (when enabled and the std is
    not ~0), then minimizes the summed negative log-likelihood weighted
    by returns over minibatches for ``repeat`` sweeps.
    """
    recorded = []
    ret = batch.returns
    if self._rew_norm and not np.isclose(ret.std(), 0):
        batch.returns = (ret - ret.mean()) / ret.std()
    for _ in range(repeat):
        for chunk in batch.split(batch_size):
            self.optim.zero_grad()
            dist = self(chunk).dist
            device = dist.logits.device
            act = torch.tensor(chunk.act, device=device)
            ret_t = torch.tensor(chunk.returns, device=device)
            nll = -(dist.log_prob(act) * ret_t).sum()
            nll.backward()
            self.optim.step()
            recorded.append(nll.item())
    return {'loss': recorded}
def learn(  # type: ignore
    self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]:
    """REINFORCE update: minimize -E[log pi(a|s) * G_t] per minibatch."""
    history = []
    for _ in range(repeat):
        for chunk in batch.split(batch_size, merge_last=True):
            self.optim.zero_grad()
            out = self(chunk)
            act = to_torch_as(chunk.act, out.act)
            ret = to_torch(chunk.returns, torch.float, out.act.device)
            # leading event dim preserved for multi-dimensional actions
            logp = out.dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)
            objective = -(logp * ret).mean()
            objective.backward()
            self.optim.step()
            history.append(objective.item())
    return {"loss": history}
def learn(self, batch: Batch, batch_size: int, repeat: int,
          **kwargs) -> Dict[str, List[float]]:
    """Legacy PPO update: clipped surrogate, optional value clipping and
    unconditional gradient-norm clipping."""
    # remember minibatch size for process_fn's evaluation splits
    self._batch = batch_size
    losses, clip_losses, vf_losses, ent_losses = [], [], [], []
    for _ in range(repeat):
        for b in batch.split(batch_size):
            dist = self(b).dist
            value = self.critic(b.obs).flatten()
            # importance ratio pi_new/pi_old from stored old log-probs
            ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
            ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
            surr1 = ratio * b.adv
            surr2 = ratio.clamp(1. - self._eps_clip,
                                1. + self._eps_clip) * b.adv
            if self._dual_clip:
                clip_loss = -torch.max(torch.min(surr1, surr2),
                                       self._dual_clip * b.adv).mean()
            else:
                clip_loss = -torch.min(surr1, surr2).mean()
            clip_losses.append(clip_loss.item())
            if self._value_clip:
                # clip value update around old estimate b.v
                v_clip = b.v + (value - b.v).clamp(
                    -self._eps_clip, self._eps_clip)
                vf1 = (b.returns - value).pow(2)
                vf2 = (b.returns - v_clip).pow(2)
                vf_loss = .5 * torch.max(vf1, vf2).mean()
            else:
                vf_loss = .5 * (b.returns - value).pow(2).mean()
            vf_losses.append(vf_loss.item())
            e_loss = dist.entropy().mean()
            ent_losses.append(e_loss.item())
            loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss
            losses.append(loss.item())
            self.optim.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(
                list(self.actor.parameters()) +
                list(self.critic.parameters()),
                self._max_grad_norm)
            self.optim.step()
    return {
        'loss': losses,
        'loss/clip': clip_losses,
        'loss/vf': vf_losses,
        'loss/ent': ent_losses,
    }
def test_batch():
    """Basic Batch behavior: getitem, append, and split(shuffle=False)."""
    batch = Batch(obs=[0], np=np.zeros([3, 4]))
    assert batch.obs == batch["obs"]
    batch.obs = [1]
    assert batch.obs == [1]
    batch.append(batch)
    assert batch.obs == [1, 1]
    assert batch.np.shape == (6, 4)
    assert batch[0].obs == batch[1].obs
    batch.obs = np.arange(5)
    for i, b in enumerate(batch.split(1, shuffle=False)):
        if i != 5:
            assert b.obs == batch[i].obs
        else:
            # NOTE(review): i never reaches 5 when splitting 5 obs one at a
            # time — this error-path check looks unreachable; confirm intent.
            with pytest.raises(AttributeError):
                batch[i].obs
            with pytest.raises(AttributeError):
                b.obs
    print(batch)
def process_fn(
    self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> Batch:
    """PG preprocessing: lambda-returns with running-return normalization.

    V(s') is evaluated without gradients; when ``_rew_norm`` is on, the
    bootstrap values are first mapped back to the unnormalized return
    scale, and the fresh returns are re-normalized with the running
    statistics (which are then updated).
    """
    with torch.no_grad():
        parts = [
            to_numpy(self.critic(chunk.obs_next))
            for chunk in batch.split(self._batch, shuffle=False, merge_last=True)
        ]
    next_values = np.concatenate(parts, axis=0)
    if self._rew_norm:  # unnormalize v_s_
        next_values = next_values * np.sqrt(self.ret_rms.var + self._eps) \
            + self.ret_rms.mean
    unnormalized_returns, _ = self.compute_episodic_return(
        batch, buffer, indice, v_s_=next_values,
        gamma=self._gamma, gae_lambda=self._lambda)
    if not self._rew_norm:
        batch.returns = unnormalized_returns
    else:
        batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
            np.sqrt(self.ret_rms.var + self._eps)
        self.ret_rms.update(unnormalized_returns)
    return batch
def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int,
        **kwargs: Any) -> Dict[str, List[float]]:
    """REINFORCE update; steps the LR scheduler once per call if present."""
    recorded: List[float] = []
    for _ in range(repeat):
        for chunk in batch.split(batch_size, merge_last=True):
            self.optim.zero_grad()
            out = self(chunk)
            act = to_torch_as(chunk.act, out.act)
            ret = to_torch_as(chunk.returns, out.act)
            # leading event dim preserved for multi-dimensional actions
            logp = out.dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)
            step_loss = -(logp * ret).mean()
            step_loss.backward()
            self.optim.step()
            recorded.append(step_loss.item())
    # update learning rate if lr_scheduler is given
    if self.lr_scheduler is not None:
        self.lr_scheduler.step()
    return {"loss": recorded}
def learn(self, batch: Batch, batch_size: int, repeat: int,
          **kwargs: Any) -> Dict[str, float]:
    """MARWIL-style update: exponentially advantage-weighted policy loss
    plus a value-regression loss.

    Fixes over the previous draft:
    * ``dist.log_prob(a)`` referenced an undefined name — now uses
      ``actions``;
    * the ``exp_advs = `` statement was left empty (a syntax error) — now
      computes clipped, exponentially weighted advantages;
    * ``explained_variance`` in the returned stats was never defined.
    """
    for _ in range(repeat):
        for b in batch.split(batch_size, merge_last=True):
            self.optim.zero_grad()
            state_values = self.value_net(b.obs, b.act).flatten()
            advantages = b.advantages.flatten()
            dist = self(b).dist
            actions = to_torch_as(b.act, dist.logits)
            rewards = to_torch_as(b.returns, dist.logits)
            # Advantage estimation: residual w.r.t. the current value net
            adv = advantages - state_values
            adv_squared = (adv.pow(2)).mean()
            # Value loss
            v_loss = 0.5 * adv_squared
            # Policy loss
            # Exponentially weighted advantages; exponent is clamped for
            # numerical stability. `beta` is the weighting temperature
            # (beta -> 0 degenerates to behavior cloning).
            beta = getattr(self, "beta", 1.0)
            exp_advs = torch.exp(torch.clamp(beta * adv, max=10.0))
            # log\pi_\theta(a|s)
            log_prob = dist.log_prob(actions).reshape(
                len(rewards), -1).transpose(0, 1)
            p_loss = -1.0 * (log_prob * exp_advs.detach()).mean()
            # Combine both losses
            loss = p_loss + self.vf_coeff * v_loss
            loss.backward()
            self.optim.step()
    # fraction of advantage-target variance explained by the value net
    explained_variance = 1.0 - adv_squared / advantages.var().clamp_min(1e-8)
    return {
        "policy_loss": p_loss.item(),
        "vf_loss": v_loss.item(),
        "total_loss": loss.item(),
        "vf_explained_var": explained_variance.item(),
    }
def learn(self, batch: Batch, *args: Any, **kwargs: Any) -> Dict[str, float]:
    """Accumulate tabular transition/reward statistics and update the
    posterior-sampling (PSRL) model.

    Iterates transitions one at a time; ``obs``/``act``/``obs_next`` index
    the count tables directly, which assumes discrete integer-coded
    states and actions.
    """
    n_s, n_a = self.model.n_state, self.model.n_action
    trans_count = np.zeros((n_s, n_a, n_s))
    rew_sum = np.zeros((n_s, n_a))
    rew_square_sum = np.zeros((n_s, n_a))
    rew_count = np.zeros((n_s, n_a))
    for minibatch in batch.split(size=1):
        obs, act, obs_next = minibatch.obs, minibatch.act, minibatch.obs_next
        trans_count[obs, act, obs_next] += 1
        rew_sum[obs, act] += minibatch.rew
        rew_square_sum[obs, act] += minibatch.rew**2
        rew_count[obs, act] += 1
        if self._add_done_loop and minibatch.done:
            # special operation for terminal states: add a self-loop
            trans_count[obs_next, :, obs_next] += 1
            rew_count[obs_next, :] += 1
    self.model.observe(trans_count, rew_sum, rew_square_sum, rew_count)
    return {
        "psrl/rew_mean": float(self.model.rew_mean.mean()),
        "psrl/rew_std": float(self.model.rew_std.mean()),
    }
def test_Batch(): """ batch.split() batch.append() len(batch) :return: """ # data is a batch involves 4 transitions data = Batch(obs=np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1]]), rew=np.array([0, 0, 0, 1])) index = [0, 1] # pick the first 2 transition print(data[0]) print(len(data)) print("--------------------") data.append( Batch(obs=np.array([[1, 0, 0], [1, 0, 1], [1, 1, 0]]), rew=np.array([-1, -1, -1, -1]))) print(data) print(len(data)) print("--------------------") # the last batch might has size less than 3 for mini_batch in data.split(size=3, permute=False): print(mini_batch)
def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int,
        **kwargs: Any) -> Dict[str, List[float]]:
    """TRPO update: natural-gradient direction with a KL-constrained
    backtracking line search, followed by critic regression.

    Returns per-minibatch actor/critic losses, accepted step sizes and
    resulting KL values.
    """
    actor_losses, vf_losses, step_sizes, kls = [], [], [], []
    for _ in range(repeat):
        for minibatch in batch.split(batch_size, merge_last=True):
            # optimize actor
            # direction: calculate vanilla gradient
            dist = self(minibatch).dist  # TODO could come from batch
            ratio = (dist.log_prob(minibatch.act) -
                     minibatch.logp_old).exp().float()
            ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
            actor_loss = -(ratio * minibatch.adv).mean()
            flat_grads = self._get_flat_grad(
                actor_loss, self.actor, retain_graph=True).detach()
            # direction: calculate natural gradient
            with torch.no_grad():
                old_dist = self(minibatch).dist
            kl = kl_divergence(old_dist, dist).mean()
            # calculate first order gradient of kl with respect to theta
            flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True)
            # CG solves F^-1 g without materializing the Fisher matrix
            search_direction = -self._conjugate_gradients(
                flat_grads, flat_kl_grad, nsteps=10)
            # stepsize: calculate max stepsize constrained by kl bound
            step_size = torch.sqrt(
                2 * self._delta /
                (search_direction *
                 self._MVP(search_direction, flat_kl_grad)
                 ).sum(0, keepdim=True))
            # stepsize: linesearch stepsize
            with torch.no_grad():
                flat_params = torch.cat([
                    param.data.view(-1) for param in self.actor.parameters()
                ])
                for i in range(self._max_backtracks):
                    new_flat_params = flat_params + step_size * search_direction
                    self._set_from_flat_params(self.actor, new_flat_params)
                    # calculate kl and, if in bound, whether loss actually
                    # went down
                    new_dist = self(minibatch).dist
                    new_dratio = (
                        new_dist.log_prob(minibatch.act) -
                        minibatch.logp_old).exp().float()
                    new_dratio = new_dratio.reshape(
                        new_dratio.size(0), -1).transpose(0, 1)
                    new_actor_loss = -(new_dratio * minibatch.adv).mean()
                    kl = kl_divergence(old_dist, new_dist).mean()
                    if kl < self._delta and new_actor_loss < actor_loss:
                        if i > 0:
                            warnings.warn(f"Backtracking to step {i}.")
                        break
                    elif i < self._max_backtracks - 1:
                        # shrink the step and try again
                        step_size = step_size * self._backtrack_coeff
                    else:
                        # NOTE(review): on total line-search failure this
                        # re-applies new_flat_params (the last, smallest
                        # candidate) rather than restoring flat_params —
                        # confirm whether restoring the original parameters
                        # was intended here.
                        self._set_from_flat_params(self.actor, new_flat_params)
                        step_size = torch.tensor([0.0])
                        warnings.warn(
                            "Line search failed! It seems hyperparamters"
                            " are poor and need to be changed.")
            # optimize critic
            for _ in range(self._optim_critic_iters):
                value = self.critic(minibatch.obs).flatten()
                vf_loss = F.mse_loss(minibatch.returns, value)
                self.optim.zero_grad()
                vf_loss.backward()
                self.optim.step()
            actor_losses.append(actor_loss.item())
            vf_losses.append(vf_loss.item())
            step_sizes.append(step_size.item())
            kls.append(kl.item())
    return {
        "loss/actor": actor_losses,
        "loss/vf": vf_losses,
        "step_size": step_sizes,
        "kl": kls,
    }
def test_batch():
    """Exercise Batch construction, emptiness, update/pop, cat_, split,
    indexing/slicing, arithmetic, and object-dtype corner cases."""
    # --- emptiness semantics (shallow vs. recurse) --------------------
    assert list(Batch()) == []
    assert Batch().is_empty()
    assert not Batch(b={'c': {}}).is_empty()
    assert Batch(b={'c': {}}).is_empty(recurse=True)
    assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
    assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
    assert not Batch(d=1).is_empty()
    assert not Batch(a=np.float64(1.0)).is_empty()
    assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
    assert not Batch(a=[1, 2, 3]).is_empty()
    # --- object-dtype storage (None, callables, classes) --------------
    b = Batch({'a': [4, 4], 'b': [5, 5]}, c=[None, None])
    assert b.c.dtype == object
    b = Batch(d=[None], e=[starmap], f=Batch)
    assert b.d.dtype == b.e.dtype == object and b.f == Batch
    # --- update/pop mirror dict semantics -----------------------------
    b = Batch()
    b.update()
    assert b.is_empty()
    b.update(c=[3, 5])
    assert np.allclose(b.c, [3, 5])
    # mimic the behavior of dict.update, where kwargs can overwrite keys
    b.update({'a': 2}, a=3)
    assert 'a' in b and b.a == 3
    assert b.pop('a') == 3
    assert 'a' not in b
    # --- construction constraints -------------------------------------
    with pytest.raises(AssertionError):
        Batch({1: 2})  # keys must be strings
    # shape-mismatched element pairs: some fall back to an object array,
    # others raise TypeError
    assert Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]).a.dtype == object
    with pytest.raises(TypeError):
        Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])
    # --- torch tensors: stacking on init, in-place concatenation ------
    batch = Batch(a=[torch.ones(3), torch.ones(3)])
    assert torch.allclose(batch.a, torch.ones(2, 3))
    batch.cat_(batch)
    assert torch.allclose(batch.a, torch.ones(4, 3))
    Batch(a=[])
    # --- attribute vs. item access, mixed-length keys -----------------
    batch = Batch(obs=[0], np=np.zeros([3, 4]))
    assert batch.obs == batch["obs"]
    batch.obs = [1]
    assert batch.obs == [1]
    batch.cat_(batch)
    assert np.allclose(batch.obs, [1, 1])
    assert batch.np.shape == (6, 4)
    assert np.allclose(batch[0].obs, batch[1].obs)
    # NOTE(review): obs now has 5 entries while np has 6, so the final
    # split element apparently lacks obs -- presumably why i == 5 raises
    # AttributeError; confirm against Batch.split semantics.
    batch.obs = np.arange(5)
    for i, b in enumerate(batch.split(1, shuffle=False)):
        if i != 5:
            assert b.obs == batch[i].obs
        else:
            with pytest.raises(AttributeError):
                batch[i].obs
            with pytest.raises(AttributeError):
                b.obs
    print(batch)
    # --- split chunk shapes with/without merge_last -------------------
    batch = Batch(a=np.arange(10))
    with pytest.raises(AssertionError):
        list(batch.split(0))
    # (size, merge_last, expected chunks of indices 0..9)
    data = [
        (1, False, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]),
        (1, True, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]),
        (3, False, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]),
        (3, True, [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]),
        (5, False, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
        (5, True, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
        (7, False, [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]),
        (7, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (10, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (10, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (15, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (15, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (100, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (100, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
    ]
    for size, merge_last, result in data:
        bs = list(batch.split(size, shuffle=False, merge_last=merge_last))
        assert [bs[i].a.tolist() for i in range(len(bs))] == result
    # --- indexing a nested dict keeps per-leaf types ------------------
    batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])}
    batch_item = Batch({'a': [batch_dict]})[0]
    assert isinstance(batch_item.a.b, np.ndarray)
    assert batch_item.a.b == batch_dict['b']
    assert isinstance(batch_item.a.c, float)
    assert batch_item.a.c == batch_dict['c']
    assert isinstance(batch_item.a.d, torch.Tensor)
    assert batch_item.a.d == batch_dict['d']
    # --- shape, membership, and index bounds --------------------------
    batch2 = Batch(a=[{
        'b': np.float64(1.0),
        'c': np.zeros(1),
        'd': Batch(e=np.array(3.0))}])
    assert len(batch2) == 1
    assert Batch().shape == []
    assert Batch(a=1).shape == []
    assert Batch(a=set((1, 2, 1))).shape == []
    assert batch2.shape[0] == 1
    assert 'a' in batch2 and all([i in batch2.a for i in 'bcd'])
    with pytest.raises(IndexError):
        batch2[-2]
    with pytest.raises(IndexError):
        batch2[1]
    assert batch2[0].shape == []
    with pytest.raises(IndexError):
        batch2[0][0]
    with pytest.raises(TypeError):
        len(batch2[0])
    assert isinstance(batch2[0].a.c, np.ndarray)
    assert isinstance(batch2[0].a.b, np.float64)
    assert isinstance(batch2[0].a.d.e, np.float64)
    # --- round-trip through list / iterator / slice -------------------
    batch2_from_list = Batch(list(batch2))
    batch2_from_comp = Batch([e for e in batch2])
    assert batch2_from_list.a.b == batch2.a.b
    assert batch2_from_list.a.c == batch2.a.c
    assert batch2_from_list.a.d.e == batch2.a.d.e
    assert batch2_from_comp.a.b == batch2.a.b
    assert batch2_from_comp.a.c == batch2.a.c
    assert batch2_from_comp.a.d.e == batch2.a.d.e
    for batch_slice in [batch2[slice(0, 1)], batch2[:1], batch2[0:]]:
        assert batch_slice.a.b == batch2.a.b
        assert batch_slice.a.c == batch2.a.c
        assert batch_slice.a.d.e == batch2.a.d.e
    # --- element-wise arithmetic; empty sub-batches pass through ------
    batch2.a.d.f = {}
    batch2_sum = (batch2 + 1.0) * 2
    assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
    assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
    assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
    assert batch2_sum.a.d.f.is_empty()
    with pytest.raises(TypeError):
        batch2 += [1]
    # --- item assignment with dict / Batch values ---------------------
    batch3 = Batch(a={
        'c': np.zeros(1),
        'd': Batch(e=np.array([0.0]), f=np.array([3.0]))})
    batch3.a.d[0] = {'e': 4.0}
    assert batch3.a.d.e[0] == 4.0
    batch3.a.d[0] = Batch(f=5.0)
    assert batch3.a.d.f[0] == 5.0
    with pytest.raises(ValueError):
        batch3.a.d[0] = Batch(f=5.0, g=0.0)  # extra key rejected
    with pytest.raises(ValueError):
        batch3[0] = Batch(a={"c": 2, "e": 1})
    # auto convert
    batch4 = Batch(a=np.array(['a', 'b']))
    assert batch4.a.dtype == object  # auto convert to object
    batch4.update(a=np.array(['c', 'd']))
    assert list(batch4.a) == ['c', 'd']
    assert batch4.a.dtype == object  # auto convert to object
    batch5 = Batch(a=np.array([{'index': 0}]))
    assert isinstance(batch5.a, Batch)
    assert np.allclose(batch5.a.index, [0])
    batch5.b = np.array([{'index': 1}])
    assert isinstance(batch5.b, Batch)
    assert np.allclose(batch5.b.index, [1])
    # None is a valid object and can be stored in Batch
    a = Batch.stack([Batch(a=None), Batch(b=None)])
    assert a.a[0] is None and a.a[1] is None
    assert a.b[0] is None and a.b[1] is None
    # nx.Graph corner case
    assert Batch(a=np.array([nx.Graph(), nx.Graph()],
                            dtype=object)).a.dtype == object
    g1 = nx.Graph()
    g1.add_nodes_from(list(range(10)))
    g2 = nx.Graph()
    g2.add_nodes_from(list(range(20)))
    assert Batch(a=np.array([g1, g2])).a.dtype == object
def learn(self, batch: Batch, batch_size: int, repeat: int,
          **kwargs) -> Dict[str, List[float]]:
    """Run ``repeat`` epochs of clipped-surrogate (PPO-style) updates.

    First pass (no grad): cache old values, old log-probs, normalized
    returns and advantages on ``batch``. Then, for each epoch and
    minibatch, minimize clip_loss + w_vf * vf_loss - w_ent * ent_loss.

    :return: per-minibatch histories of the total, clip, value and
        entropy losses.
    """
    self._batch = batch_size
    losses, clip_losses, vf_losses, ent_losses = [], [], [], []
    # --- frozen targets: old value estimates and old log-probs --------
    v = []
    old_log_prob = []
    with torch.no_grad():
        for minibatch in batch.split(batch_size, shuffle=False):
            v.append(self.critic(minibatch.obs))
            act = torch.tensor(minibatch.act, device=v[0].device)
            old_log_prob.append(self(minibatch).dist.log_prob(act))
    batch.v = torch.cat(v, dim=0)  # old value
    dev = batch.v.device
    batch.act = torch.tensor(batch.act, dtype=torch.float, device=dev)
    batch.logp_old = torch.cat(old_log_prob, dim=0)
    batch.returns = torch.tensor(
        batch.returns, dtype=torch.float, device=dev
    ).reshape(batch.v.shape)
    # optional standardization of returns, then of advantages
    if self._rew_norm:
        ret_mean, ret_std = batch.returns.mean(), batch.returns.std()
        if ret_std > self.__eps:
            batch.returns = (batch.returns - ret_mean) / ret_std
    batch.adv = batch.returns - batch.v
    if self._rew_norm:
        adv_mean, adv_std = batch.adv.mean(), batch.adv.std()
        if adv_std > self.__eps:
            batch.adv = (batch.adv - adv_mean) / adv_std
    # --- epochs of minibatch gradient steps ---------------------------
    for _ in range(repeat):
        for minibatch in batch.split(batch_size):
            dist = self(minibatch).dist
            value = self.critic(minibatch.obs)
            # importance ratio pi_theta(a|s) / pi_theta_old(a|s)
            ratio = (dist.log_prob(minibatch.act) -
                     minibatch.logp_old).exp().float()
            unclipped = ratio * minibatch.adv
            clipped = ratio.clamp(
                1. - self._eps_clip, 1. + self._eps_clip) * minibatch.adv
            surrogate = torch.min(unclipped, clipped)
            if self._dual_clip:
                # dual clip lower-bounds the objective for adv < 0
                clip_loss = -torch.max(
                    surrogate, self._dual_clip * minibatch.adv).mean()
            else:
                clip_loss = -surrogate.mean()
            clip_losses.append(clip_loss.item())
            if self._value_clip:
                # clipped value regression around the old value estimate
                v_clip = minibatch.v + (value - minibatch.v).clamp(
                    -self._eps_clip, self._eps_clip)
                vf_loss = .5 * torch.max(
                    (minibatch.returns - value).pow(2),
                    (minibatch.returns - v_clip).pow(2)).mean()
            else:
                vf_loss = .5 * (minibatch.returns - value).pow(2).mean()
            vf_losses.append(vf_loss.item())
            e_loss = dist.entropy().mean()
            ent_losses.append(e_loss.item())
            loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss
            losses.append(loss.item())
            self.optim.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(
                list(self.actor.parameters()) +
                list(self.critic.parameters()),
                self._max_grad_norm)
            self.optim.step()
    return {
        'loss': losses,
        'loss/clip': clip_losses,
        'loss/vf': vf_losses,
        'loss/ent': ent_losses,
    }