Example #1
    def learn(  # type: ignore
            self, batch: Batch, batch_size: int, repeat: int,
            **kwargs: Any) -> Dict[str, List[float]]:
        actor_losses, vf_losses, kls = [], [], []
        for step in range(repeat):
            for b in batch.split(batch_size, merge_last=True):
                # optimize actor
                # direction: calculate vanilla gradient
                dist = self(b).dist  # TODO could come from batch
                ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
                ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                actor_loss = -(ratio * b.adv).mean()
                flat_grads = self._get_flat_grad(actor_loss,
                                                 self.actor,
                                                 retain_graph=True).detach()

                # direction: calculate natural gradient
                with torch.no_grad():
                    old_dist = self(b).dist

                kl = kl_divergence(old_dist, dist).mean()
                # calculate first order gradient of kl with respect to theta
                flat_kl_grad = self._get_flat_grad(kl,
                                                   self.actor,
                                                   create_graph=True)
                search_direction = -self._conjugate_gradients(
                    flat_grads, flat_kl_grad, nsteps=10)

                # step
                with torch.no_grad():
                    flat_params = torch.cat([
                        param.data.view(-1)
                        for param in self.actor.parameters()
                    ])
                    new_flat_params = flat_params + self._step_size * search_direction
                    self._set_from_flat_params(self.actor, new_flat_params)
                    new_dist = self(b).dist
                    kl = kl_divergence(old_dist, new_dist).mean()

                # optimize critic
                for _ in range(self._optim_critic_iters):
                    value = self.critic(b.obs).flatten()
                    vf_loss = F.mse_loss(b.returns, value)
                    self.optim.zero_grad()
                    vf_loss.backward()
                    self.optim.step()

                actor_losses.append(actor_loss.item())
                vf_losses.append(vf_loss.item())
                kls.append(kl.item())

        # update learning rate if lr_scheduler is given
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return {
            "loss/actor": actor_losses,
            "loss/vf": vf_losses,
            "kl": kls,
        }
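The snippet calls `self._conjugate_gradients`, which is not shown. A minimal standalone sketch of the conjugate-gradient solve it performs, with `mvp(p)` standing in for the Hessian-vector product that `self._MVP` would compute from `flat_kl_grad`, might look like this:

import torch

def conjugate_gradients(mvp, g, nsteps=10, residual_tol=1e-10):
    """Solve H x = g, where mvp(p) returns the Hessian-vector product H @ p."""
    x = torch.zeros_like(g)
    r, p = g.clone(), g.clone()
    rdotr = r.dot(r)
    for _ in range(nsteps):
        z = mvp(p)
        alpha = rdotr / p.dot(z)
        x += alpha * p       # move along the conjugate direction
        r -= alpha * z       # update the residual
        new_rdotr = r.dot(r)
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x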
Example #2
 def learn(  # type: ignore
         self, batch: Batch, batch_size: int, repeat: int,
         **kwargs: Any) -> Dict[str, List[float]]:
     losses, actor_losses, vf_losses, ent_losses = [], [], [], []
     for _ in range(repeat):
         for b in batch.split(batch_size, merge_last=True):
             self.optim.zero_grad()
             dist = self(b).dist
             v = self.critic(b.obs).flatten()
             a = to_torch_as(b.act, v)
             r = to_torch_as(b.returns, v)
             log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
             a_loss = -(log_prob * (r - v).detach()).mean()
             vf_loss = F.mse_loss(r, v)  # type: ignore
             ent_loss = dist.entropy().mean()
             loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
             loss.backward()
             if self._grad_norm is not None:
                 nn.utils.clip_grad_norm_(
                     list(self.actor.parameters()) +
                     list(self.critic.parameters()),
                     max_norm=self._grad_norm,
                 )
             self.optim.step()
             actor_losses.append(a_loss.item())
             vf_losses.append(vf_loss.item())
             ent_losses.append(ent_loss.item())
             losses.append(loss.item())
     return {
         "loss": losses,
         "loss/actor": actor_losses,
         "loss/vf": vf_losses,
         "loss/ent": ent_losses,
     }
Example #3
 def learn(self, batch: Batch, batch_size: int, repeat: int,
           **kwargs) -> Dict[str, List[float]]:
     self._batch = batch_size
     r = batch.returns
     if self._rew_norm and r.std() > self.__eps:
         batch.returns = (r - r.mean()) / r.std()
     losses, actor_losses, vf_losses, ent_losses = [], [], [], []
     for _ in range(repeat):
         for b in batch.split(batch_size):
             self.optim.zero_grad()
             dist = self(b).dist
             v = self.critic(b.obs)
             a = torch.tensor(b.act, device=v.device)
             r = torch.tensor(b.returns, device=v.device)
             a_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
             vf_loss = F.mse_loss(r[:, None], v)
             ent_loss = dist.entropy().mean()
             loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
             loss.backward()
             if self._grad_norm:
                 nn.utils.clip_grad_norm_(list(self.actor.parameters()) +
                                          list(self.critic.parameters()),
                                          max_norm=self._grad_norm)
             self.optim.step()
             actor_losses.append(a_loss.item())
             vf_losses.append(vf_loss.item())
             ent_losses.append(ent_loss.item())
             losses.append(loss.item())
     return {
         'loss': losses,
         'loss/actor': actor_losses,
         'loss/vf': vf_losses,
         'loss/ent': ent_losses,
     }
Example #4
 def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                indice: np.ndarray) -> Batch:
     if self._rew_norm:
         mean, std = batch.rew.mean(), batch.rew.std()
         if not np.isclose(std, 0, atol=1e-2):
             batch.rew = (batch.rew - mean) / std
     v, v_, old_log_prob = [], [], []
     with torch.no_grad():
         for b in batch.split(self._batch, shuffle=False):
             v_.append(self.critic(b.obs_next))
             v.append(self.critic(b.obs))
             old_log_prob.append(self(b).dist.log_prob(
                 to_torch_as(b.act, v[0])))
     v_ = to_numpy(torch.cat(v_, dim=0))
     batch = self.compute_episodic_return(
         batch, v_, gamma=self._gamma, gae_lambda=self._lambda,
         rew_norm=self._rew_norm)
     batch.v = torch.cat(v, dim=0).flatten()  # old value
     batch.act = to_torch_as(batch.act, v[0])
     batch.logp_old = torch.cat(old_log_prob, dim=0)
     batch.returns = to_torch_as(batch.returns, v[0])
     batch.adv = batch.returns - batch.v
     if self._rew_norm:
         mean, std = batch.adv.mean(), batch.adv.std()
         if not np.isclose(std.item(), 0, atol=1e-2):
             batch.adv = (batch.adv - mean) / std
     return batch
Example #5
 def _compute_returns(self, batch: Batch, buffer: ReplayBuffer,
                      indice: np.ndarray) -> Batch:
     v_s, v_s_ = [], []
     with torch.no_grad():
         for b in batch.split(self._batch, shuffle=False, merge_last=True):
             v_s.append(self.critic(b.obs))
             v_s_.append(self.critic(b.obs_next))
     batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
     v_s = batch.v_s.cpu().numpy()
     v_s_ = torch.cat(v_s_, dim=0).flatten().cpu().numpy()
     # When normalizing values, we do not subtract self.ret_rms.mean, to stay
     # numerically consistent with OpenAI Baselines' value normalization
     # pipeline. Empirically, subtracting the mean also hurts performance
     # slightly on MuJoCo environments, though the reason is unclear.
     if self._rew_norm:  # unnormalize v_s & v_s_
         v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
         v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
     unnormalized_returns, advantages = self.compute_episodic_return(
         batch,
         buffer,
         indice,
         v_s_,
         v_s,
         gamma=self._gamma,
         gae_lambda=self._lambda)
     if self._rew_norm:
         batch.returns = unnormalized_returns / \
             np.sqrt(self.ret_rms.var + self._eps)
         self.ret_rms.update(unnormalized_returns)
     else:
         batch.returns = unnormalized_returns
     batch.returns = to_torch_as(batch.returns, batch.v_s)
     batch.adv = to_torch_as(advantages, batch.v_s)
     return batch
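`compute_episodic_return` is inherited from the base policy and not shown here. A minimal sketch of the GAE recursion behind the `(unnormalized_returns, advantages)` pair, assuming flat per-step numpy arrays of rewards, done flags, and value estimates, might look like:

import numpy as np

def gae_sketch(rew, done, v_s, v_s_, gamma=0.99, gae_lambda=0.95):
    """delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)."""
    adv = np.zeros_like(rew, dtype=float)
    gae = 0.0
    for t in reversed(range(len(rew))):
        delta = rew[t] + gamma * v_s_[t] * (1.0 - done[t]) - v_s[t]
        gae = delta + gamma * gae_lambda * (1.0 - done[t]) * gae
        adv[t] = gae
    return adv + v_s, adv  # (unnormalized returns, advantages)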
Example #6
 def learn(  # type: ignore
         self, batch: Batch, batch_size: int, repeat: int,
         **kwargs: Any) -> Dict[str, List[float]]:
     # update discriminator
     losses = []
     acc_pis = []
     acc_exps = []
     bsz = len(batch) // self.disc_update_num
     for b in batch.split(bsz, merge_last=True):
         logits_pi = self.disc(b)
         exp_b = self.expert_buffer.sample(bsz)[0]
         logits_exp = self.disc(exp_b)
         loss_pi = -F.logsigmoid(-logits_pi).mean()
         loss_exp = -F.logsigmoid(logits_exp).mean()
         loss_disc = loss_pi + loss_exp
         self.disc_optim.zero_grad()
         loss_disc.backward()
         self.disc_optim.step()
         losses.append(loss_disc.item())
         acc_pis.append((logits_pi < 0).float().mean().item())
         acc_exps.append((logits_exp > 0).float().mean().item())
     # update policy
     res = super().learn(batch, batch_size, repeat, **kwargs)
     res["loss/disc"] = losses
     res["stats/acc_pi"] = acc_pis
     res["stats/acc_exp"] = acc_exps
     return res
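The discriminator trained above is typically used to relabel policy rewards before the policy update, a step the snippet itself does not show. A hedged sketch (the helper name `gail_reward` is hypothetical): a transition the discriminator scores as expert-like (positive logits) receives a higher learned reward.

import torch
import torch.nn.functional as F

def gail_reward(disc, batch):
    # -log(1 - sigmoid(logits)): grows as the discriminator
    # believes the transition came from the expert
    with torch.no_grad():
        return -F.logsigmoid(-disc(batch)).flatten()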
Example #7
 def process_fn(
     self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
 ) -> Batch:
     v_s, v_s_, old_log_prob = [], [], []
     with torch.no_grad():
         for b in batch.split(self._batch, shuffle=False, merge_last=True):
             v_s.append(self.critic(b.obs))
             v_s_.append(self.critic(b.obs_next))
             old_log_prob.append(self(b).dist.log_prob(to_torch_as(b.act, v_s[0])))
     batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
     v_s = to_numpy(batch.v_s)
     v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
     if self._rew_norm:  # unnormalize v_s & v_s_
         v_s = v_s * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
         v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
     unnormalized_returns, advantages = self.compute_episodic_return(
         batch, buffer, indice, v_s_, v_s,
         gamma=self._gamma, gae_lambda=self._lambda)
     if self._rew_norm:
         batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
             np.sqrt(self.ret_rms.var + self._eps)
         self.ret_rms.update(unnormalized_returns)
         mean, std = np.mean(advantages), np.std(advantages)
         advantages = (advantages - mean) / std  # per-batch norm
     else:
         batch.returns = unnormalized_returns
     batch.act = to_torch_as(batch.act, batch.v_s)
     batch.logp_old = torch.cat(old_log_prob, dim=0)
     batch.returns = to_torch_as(batch.returns, batch.v_s)
     batch.adv = to_torch_as(advantages, batch.v_s)
     return batch
Example #8
    def learn(  # type: ignore
            self, batch: Batch, batch_size: int, repeat: int,
            **kwargs: Any) -> Dict[str, List[float]]:
        losses, clip_losses, vf_losses, ent_losses = [], [], [], []
        for step in range(repeat):
            if self._recompute_adv and step > 0:
                batch = self._compute_returns(batch, self._buffer,
                                              self._indice)
            for b in batch.split(batch_size, merge_last=True):
                # calculate loss for actor
                dist = self(b).dist
                if self._norm_adv:
                    mean, std = b.adv.mean(), b.adv.std()
                    b.adv = (b.adv - mean) / std  # per-batch norm
                ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
                ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                surr1 = ratio * b.adv
                surr2 = ratio.clamp(1.0 - self._eps_clip,
                                    1.0 + self._eps_clip) * b.adv
                if self._dual_clip:
                    clip_loss = -torch.max(torch.min(surr1, surr2),
                                           self._dual_clip * b.adv).mean()
                else:
                    clip_loss = -torch.min(surr1, surr2).mean()
                # calculate loss for critic
                value = self.critic(b.obs).flatten()
                if self._value_clip:
                    v_clip = b.v_s + (value - b.v_s).clamp(
                        -self._eps_clip, self._eps_clip)
                    vf1 = (b.returns - value).pow(2)
                    vf2 = (b.returns - v_clip).pow(2)
                    vf_loss = torch.max(vf1, vf2).mean()
                else:
                    vf_loss = (b.returns - value).pow(2).mean()
                # calculate regularization and overall loss
                ent_loss = dist.entropy().mean()
                loss = clip_loss + self._weight_vf * vf_loss \
                    - self._weight_ent * ent_loss
                self.optim.zero_grad()
                loss.backward()
                if self._grad_norm:  # clip large gradient
                    nn.utils.clip_grad_norm_(list(self.actor.parameters()) +
                                             list(self.critic.parameters()),
                                             max_norm=self._grad_norm)
                self.optim.step()
                clip_losses.append(clip_loss.item())
                vf_losses.append(vf_loss.item())
                ent_losses.append(ent_loss.item())
                losses.append(loss.item())
        # update learning rate if lr_scheduler is given
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return {
            "loss": losses,
            "loss/clip": clip_losses,
            "loss/vf": vf_losses,
            "loss/ent": ent_losses,
        }
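As a quick sanity check on the clipped surrogate above, a toy computation with made-up values shows how the clip caps the incentive once the ratio leaves `[1 - eps_clip, 1 + eps_clip]`:

import torch

# with eps_clip = 0.2, a ratio of 1.5 is capped at 1.2, so the objective
# stops rewarding further movement away from the old policy
ratio = torch.tensor([0.5, 1.0, 1.5])
adv = torch.ones(3)
surr1 = ratio * adv
surr2 = ratio.clamp(0.8, 1.2) * adv
print(-torch.min(surr1, surr2).mean())  # tensor(-0.9000)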
Example #9
def test_batch():
    batch = Batch(obs=[0], np=np.zeros([3, 4]))
    assert batch.obs == batch["obs"]
    batch.obs = [1]
    assert batch.obs == [1]
    batch.cat_(batch)
    assert batch.obs == [1, 1]
    assert batch.np.shape == (6, 4)
    assert batch[0].obs == batch[1].obs
    batch.obs = np.arange(5)
    for i, b in enumerate(batch.split(1, shuffle=False)):
        if i != 5:
            assert b.obs == batch[i].obs
        else:
            with pytest.raises(AttributeError):
                batch[i].obs
            with pytest.raises(AttributeError):
                b.obs
    print(batch)
    batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])}
    batch_item = Batch({'a': [batch_dict]})[0]
    assert isinstance(batch_item.a.b, np.ndarray)
    assert batch_item.a.b == batch_dict['b']
    assert isinstance(batch_item.a.c, float)
    assert batch_item.a.c == batch_dict['c']
    assert isinstance(batch_item.a.d, torch.Tensor)
    assert batch_item.a.d == batch_dict['d']
    batch2 = Batch(a=[{
        'b': np.float64(1.0),
        'c': np.zeros(1),
        'd': Batch(e=np.array(3.0))}])
    assert len(batch2) == 1
    with pytest.raises(IndexError):
        batch2[-2]
    with pytest.raises(IndexError):
        batch2[1]
    with pytest.raises(TypeError):
        batch2[0][0]
    assert isinstance(batch2[0].a.c, np.ndarray)
    assert isinstance(batch2[0].a.b, np.float64)
    assert isinstance(batch2[0].a.d.e, np.float64)
    batch2_from_list = Batch(list(batch2))
    batch2_from_comp = Batch([e for e in batch2])
    assert batch2_from_list.a.b == batch2.a.b
    assert batch2_from_list.a.c == batch2.a.c
    assert batch2_from_list.a.d.e == batch2.a.d.e
    assert batch2_from_comp.a.b == batch2.a.b
    assert batch2_from_comp.a.c == batch2.a.c
    assert batch2_from_comp.a.d.e == batch2.a.d.e
    for batch_slice in [
            batch2[slice(0, 1)], batch2[:1], batch2[0:]]:
        assert batch_slice.a.b == batch2.a.b
        assert batch_slice.a.c == batch2.a.c
        assert batch_slice.a.d.e == batch2.a.d.e
    batch2_sum = (batch2 + 1.0) * 2
    assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
    assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
    assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
Example #10
 def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                indices: np.ndarray) -> Batch:
     batch = super().process_fn(batch, buffer, indices)
     old_log_prob = []
     with torch.no_grad():
         for b in batch.split(self._batch, shuffle=False, merge_last=True):
             old_log_prob.append(self(b).dist.log_prob(b.act))
     batch.logp_old = torch.cat(old_log_prob, dim=0)
     if self._norm_adv:
         batch.adv = (batch.adv - batch.adv.mean()) / batch.adv.std()
     return batch
Example #11
 def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                indice: np.ndarray) -> Batch:
     if self._recompute_adv:
         # store `buffer` and `indice` for use in `learn()`
         self._buffer, self._indice = buffer, indice
     batch = self._compute_returns(batch, buffer, indice)
     batch.act = to_torch_as(batch.act, batch.v_s)
     old_log_prob = []
     with torch.no_grad():
         for b in batch.split(self._batch, shuffle=False, merge_last=True):
             old_log_prob.append(self(b).dist.log_prob(b.act))
     batch.logp_old = torch.cat(old_log_prob, dim=0)
     return batch
Example #12
def test_batch():
    batch = Batch(obs=[0], np=np.zeros([3, 4]))
    batch.obs = [1]
    assert batch.obs == [1]
    batch.append(batch)
    assert batch.obs == [1, 1]
    assert batch.np.shape == (6, 4)
    assert batch[0].obs == batch[1].obs
    with pytest.raises(IndexError):
        batch[2]
    batch.obs = np.arange(5)
    for i, b in enumerate(batch.split(1, permute=False)):
        assert b.obs == batch[i].obs
Example #13
 def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                indice: np.ndarray) -> Batch:
     if self._lambda in [0, 1]:
         return self.compute_episodic_return(
             batch, None, gamma=self._gamma, gae_lambda=self._lambda)
     v_ = []
     with torch.no_grad():
         for b in batch.split(self._batch, shuffle=False, merge_last=True):
             v_.append(to_numpy(self.critic(b.obs_next)))
     v_ = np.concatenate(v_, axis=0)
     return self.compute_episodic_return(
         batch, v_, gamma=self._gamma, gae_lambda=self._lambda,
         rew_norm=self._rew_norm)
Example #14
    def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
    ) -> Dict[str, List[float]]:
        losses, actor_losses, vf_losses, ent_losses = [], [], [], []
        for _ in range(repeat):
            for b in batch.split(batch_size, merge_last=True):
                # calculate loss for actor
                dist = self(b).dist
                if self._norm_adv:
                    mean, std = b.adv.mean(), b.adv.std()
                    b.adv = (b.adv - mean) / std  # per-batch norm
                log_prob = dist.log_prob(b.act).reshape(len(b.adv), -1).transpose(0, 1)
                actor_loss = -(log_prob * b.adv).mean()
                # calculate loss for critic
                value = self.critic(b.obs).flatten()
                vf_loss = F.mse_loss(b.returns, value)
                # calculate regularization and overall loss
                ent_loss = dist.entropy().mean()
                loss = actor_loss + self._weight_vf * vf_loss \
                    - self._weight_ent * ent_loss
                if self.optim.steps % self.optim.Ts == 0:
                    # Compute fisher, see Martens 2014
                    self.optim.model.zero_grad()
                    pg_fisher_loss = -log_prob.mean()
                    value_noise = torch.randn(value.size(), device=value.device)
                    sample_value = value + value_noise
                    vf_fisher_loss = -(value - sample_value.detach()).pow(2).mean()
                    fisher_loss = pg_fisher_loss + vf_fisher_loss
                    self.optim.acc_stats = True
                    fisher_loss.backward(retain_graph=True)
                    self.optim.acc_stats = False
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
                actor_losses.append(actor_loss.item())
                vf_losses.append(vf_loss.item())
                ent_losses.append(ent_loss.item())
                losses.append(loss.item())
        # update learning rate if lr_scheduler is given
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return {
            "loss": losses,
            "loss/actor": actor_losses,
            "loss/vf": vf_losses,
            "loss/ent": ent_losses,
        }
Example #15
    def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
    ) -> Dict[str, List[float]]:
        losses, clip_losses, vf_losses, ent_losses = [], [], [], []
        for _ in range(repeat):
            for b in batch.split(batch_size, merge_last=True):
                dist = self(b).dist
                value = self.critic(b.obs).flatten()
                ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
                ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                surr1 = ratio * b.adv
                surr2 = ratio.clamp(1.0 - self._eps_clip, 1.0 + self._eps_clip) * b.adv
                if self._dual_clip:
                    clip_loss = -torch.max(
                        torch.min(surr1, surr2), self._dual_clip * b.adv
                    ).mean()
                else:
                    clip_loss = -torch.min(surr1, surr2).mean()
                clip_losses.append(clip_loss.item())
                if self._value_clip:
                    v_clip = b.v_s + (value - b.v_s).clamp(
                        -self._eps_clip, self._eps_clip)
                    vf1 = (b.returns - value).pow(2)
                    vf2 = (b.returns - v_clip).pow(2)
                    vf_loss = 0.5 * torch.max(vf1, vf2).mean()
                else:
                    vf_loss = 0.5 * (b.returns - value).pow(2).mean()
                vf_losses.append(vf_loss.item())
                e_loss = dist.entropy().mean()
                ent_losses.append(e_loss.item())
                loss = clip_loss + self._weight_vf * vf_loss \
                    - self._weight_ent * e_loss
                losses.append(loss.item())
                self.optim.zero_grad()
                loss.backward()
                if self._grad_norm is not None:
                    nn.utils.clip_grad_norm_(
                        list(self.actor.parameters()) + list(self.critic.parameters()),
                        self._grad_norm)
                self.optim.step()
        # update learning rate if lr_scheduler is given
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return {
            "loss": losses,
            "loss/clip": clip_losses,
            "loss/vf": vf_losses,
            "loss/ent": ent_losses,
        }
Example #16
    def learn(  # type: ignore
            self, batch: Batch, batch_size: int, repeat: int,
            **kwargs: Any) -> Dict[str, List[float]]:
        losses, actor_losses, vf_losses, ent_losses = [], [], [], []
        for _ in range(repeat):
            for b in batch.split(batch_size, merge_last=True):
                self.optim.zero_grad()
                with autocast(enabled=self.use_mixed):
                    # calculate loss for actor
                    dist = self(b).dist
                    log_prob = dist.log_prob(b.act).reshape(len(b.adv),
                                                            -1).transpose(
                                                                0, 1)
                    actor_loss = -(log_prob * b.adv).mean()
                    # calculate loss for critic
                    value = self.critic(b.obs).flatten()
                    vf_loss = F.mse_loss(b.returns, value)
                    # vf_loss = F.smooth_l1_loss(b.returns, value)     # Experiment
                    # calculate regularization and overall loss
                    ent_loss = dist.entropy().mean()
                    loss = actor_loss + self._weight_vf * vf_loss \
                        - self._weight_ent * ent_loss

                self.scaler.scale(loss).backward()

                if self._grad_norm:  # clip large gradient
                    self.scaler.unscale_(self.optim)
                    nn.utils.clip_grad_norm_(list(self.actor.parameters()) +
                                             list(self.critic.parameters()),
                                             max_norm=self._grad_norm)

                self.scaler.step(self.optim)
                self.scaler.update()
                actor_losses.append(actor_loss.item())
                vf_losses.append(vf_loss.item())
                ent_losses.append(ent_loss.item())
                losses.append(loss.item())
        # update learning rate if lr_scheduler is given
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return {
            "loss": losses,
            "loss/actor": actor_losses,
            "loss/vf": vf_losses,
            "loss/ent": ent_losses,
        }
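The `autocast(enabled=self.use_mixed)` / `self.scaler` pair assumes a one-time setup elsewhere in the policy, which the snippet does not show. A rough sketch of that assumed initialization:

import torch
from torch.cuda.amp import GradScaler

# sketch of the assumed setup backing the AMP usage above
use_mixed = torch.cuda.is_available()     # AMP needs a CUDA device
scaler = GradScaler(enabled=use_mixed)    # rescales the loss so small
                                          # fp16 gradients do not underflow;
                                          # a no-op when enabled=False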
Example #17
 def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                indice: np.ndarray) -> Batch:
     if self._rew_norm:
         mean, std = batch.rew.mean(), batch.rew.std()
         if std > self.__eps:
             batch.rew = (batch.rew - mean) / std
     if self._lambda in [0, 1]:
         return self.compute_episodic_return(
             batch, None, gamma=self._gamma, gae_lambda=self._lambda)
     v_ = []
     with torch.no_grad():
         for b in batch.split(self._batch, shuffle=False):
             v_.append(self.critic(b.obs_next))
     v_ = torch.cat(v_, dim=0).cpu().numpy()
     return self.compute_episodic_return(
         batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
Example #18
 def learn(  # type: ignore
         self, batch: Batch, batch_size: int, repeat: int,
         **kwargs: Any) -> Dict[str, List[float]]:
     losses = []
     for _ in range(repeat):
         for b in batch.split(batch_size, merge_last=True):
             self.optim.zero_grad()
             dist = self(b).dist
             a = to_torch_as(b.act, dist.logits)
             r = to_torch_as(b.returns, dist.logits)
             log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
             loss = -(log_prob * r).mean()
             loss.backward()
             self.optim.step()
             losses.append(loss.item())
     return {"loss": losses}
Example #19
 def learn(self, batch: Batch, batch_size: int, repeat: int,
           **kwargs) -> Dict[str, List[float]]:
     losses = []
     r = batch.returns
     if self._rew_norm and not np.isclose(r.std(), 0):
         batch.returns = (r - r.mean()) / r.std()
     for _ in range(repeat):
         for b in batch.split(batch_size):
             self.optim.zero_grad()
             dist = self(b).dist
             a = torch.tensor(b.act, device=dist.logits.device)
             r = torch.tensor(b.returns, device=dist.logits.device)
             loss = -(dist.log_prob(a) * r).sum()
             loss.backward()
             self.optim.step()
             losses.append(loss.item())
     return {'loss': losses}
Example #20
    def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
    ) -> Dict[str, List[float]]:
        losses = []
        for _ in range(repeat):
            for minibatch in batch.split(batch_size, merge_last=True):
                self.optim.zero_grad()
                result = self(minibatch)
                dist = result.dist
                act = to_torch_as(minibatch.act, result.act)
                ret = to_torch(minibatch.returns, torch.float, result.act.device)
                log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)
                loss = -(log_prob * ret).mean()
                loss.backward()
                self.optim.step()
                losses.append(loss.item())

        return {"loss": losses}
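For reference, the loss above is the score-function (REINFORCE) estimator. A minimal toy sketch with made-up logits, actions, and returns:

import torch
from torch.distributions import Categorical

# the gradient pushes probability mass toward actions with positive returns
logits = torch.zeros(3, 2, requires_grad=True)
dist = Categorical(logits=logits)
act = torch.tensor([0, 1, 0])
ret = torch.tensor([1.0, -1.0, 0.5])
loss = -(dist.log_prob(act) * ret).mean()
loss.backward()
print(logits.grad)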
Example #21
 def learn(self, batch: Batch, batch_size: int, repeat: int,
           **kwargs) -> Dict[str, List[float]]:
     self._batch = batch_size
     losses, clip_losses, vf_losses, ent_losses = [], [], [], []
     for _ in range(repeat):
         for b in batch.split(batch_size):
             dist = self(b).dist
             value = self.critic(b.obs).flatten()
             ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
             ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
             surr1 = ratio * b.adv
             surr2 = ratio.clamp(1. - self._eps_clip,
                                 1. + self._eps_clip) * b.adv
             if self._dual_clip:
                 clip_loss = -torch.max(torch.min(surr1, surr2),
                                        self._dual_clip * b.adv).mean()
             else:
                 clip_loss = -torch.min(surr1, surr2).mean()
             clip_losses.append(clip_loss.item())
             if self._value_clip:
                 v_clip = b.v + (value - b.v).clamp(
                     -self._eps_clip, self._eps_clip)
                 vf1 = (b.returns - value).pow(2)
                 vf2 = (b.returns - v_clip).pow(2)
                 vf_loss = .5 * torch.max(vf1, vf2).mean()
             else:
                 vf_loss = .5 * (b.returns - value).pow(2).mean()
             vf_losses.append(vf_loss.item())
             e_loss = dist.entropy().mean()
             ent_losses.append(e_loss.item())
             loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss
             losses.append(loss.item())
             self.optim.zero_grad()
             loss.backward()
             nn.utils.clip_grad_norm_(list(
                 self.actor.parameters()) + list(self.critic.parameters()),
                 self._max_grad_norm)
             self.optim.step()
     return {
         'loss': losses,
         'loss/clip': clip_losses,
         'loss/vf': vf_losses,
         'loss/ent': ent_losses,
     }
Example #22
def test_batch():
    batch = Batch(obs=[0], np=np.zeros([3, 4]))
    assert batch.obs == batch["obs"]
    batch.obs = [1]
    assert batch.obs == [1]
    batch.append(batch)
    assert batch.obs == [1, 1]
    assert batch.np.shape == (6, 4)
    assert batch[0].obs == batch[1].obs
    batch.obs = np.arange(5)
    for i, b in enumerate(batch.split(1, shuffle=False)):
        if i != 5:
            assert b.obs == batch[i].obs
        else:
            with pytest.raises(AttributeError):
                batch[i].obs
            with pytest.raises(AttributeError):
                b.obs
    print(batch)
Example #23
 def process_fn(
     self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
 ) -> Batch:
     v_s_ = []
     with torch.no_grad():
         for b in batch.split(self._batch, shuffle=False, merge_last=True):
             v_s_.append(to_numpy(self.critic(b.obs_next)))
     v_s_ = np.concatenate(v_s_, axis=0)
     if self._rew_norm:  # unnormalize v_s_
         v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
     unnormalized_returns, _ = self.compute_episodic_return(
         batch, buffer, indice, v_s_=v_s_,
         gamma=self._gamma, gae_lambda=self._lambda)
     if self._rew_norm:
         batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
             np.sqrt(self.ret_rms.var + self._eps)
         self.ret_rms.update(unnormalized_returns)
     else:
         batch.returns = unnormalized_returns
     return batch
Example #24
    def learn(  # type: ignore
            self, batch: Batch, batch_size: int, repeat: int,
            **kwargs: Any) -> Dict[str, List[float]]:
        losses = []
        for _ in range(repeat):
            for b in batch.split(batch_size, merge_last=True):
                self.optim.zero_grad()
                result = self(b)
                dist = result.dist
                a = to_torch_as(b.act, result.act)
                r = to_torch_as(b.returns, result.act)
                log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
                loss = -(log_prob * r).mean()
                loss.backward()
                self.optim.step()
                losses.append(loss.item())
        # update learning rate if lr_scheduler is given
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return {"loss": losses}
Example #25

    def learn(  # type: ignore
            self, batch: Batch, batch_size: int, repeat: int,
            **kwargs: Any) -> Dict[str, float]:
        for _ in range(repeat):
            for b in batch.split(batch_size, merge_last=True):
                self.optim.zero_grad()
                state_values = self.value_net(b.obs, b.act).flatten()
                advantages = b.advantages.flatten()
                dist = self(b).dist
                actions = to_torch_as(b.act, dist.logits)
                rewards = to_torch_as(b.returns, dist.logits)

                # advantage of the behavior data over the critic's estimate
                adv = advantages - state_values
                adv_squared = adv.pow(2).mean()

                # value loss
                v_loss = 0.5 * adv_squared

                # policy loss: exponentially weighted advantages
                # (MARWIL-style; the temperature `self._beta` is assumed
                # to be a hyperparameter set in __init__)
                exp_advs = torch.exp(self._beta * adv)
                # log pi_theta(a|s)
                log_prob = dist.log_prob(actions).reshape(
                    len(rewards), -1).transpose(0, 1)
                p_loss = -1.0 * (log_prob * exp_advs.detach()).mean()

                # combine both losses
                loss = p_loss + self.vf_coeff * v_loss
                loss.backward()
                self.optim.step()

        # fraction of the target's variance explained by the critic
        # (1.0 means a perfect fit)
        explained_variance = 1.0 - adv.var() / advantages.var().clamp_min(1e-8)

        return {
            "policy_loss": p_loss.item(),
            "vf_loss": v_loss.item(),
            "total_loss": loss.item(),
            "vf_explained_var": explained_variance.item(),
        }
Example #26
 def learn(self, batch: Batch, *args: Any,
           **kwargs: Any) -> Dict[str, float]:
     n_s, n_a = self.model.n_state, self.model.n_action
     trans_count = np.zeros((n_s, n_a, n_s))
     rew_sum = np.zeros((n_s, n_a))
     rew_square_sum = np.zeros((n_s, n_a))
     rew_count = np.zeros((n_s, n_a))
     for minibatch in batch.split(size=1):
         obs, act, obs_next = minibatch.obs, minibatch.act, minibatch.obs_next
         trans_count[obs, act, obs_next] += 1
         rew_sum[obs, act] += minibatch.rew
         rew_square_sum[obs, act] += minibatch.rew**2
         rew_count[obs, act] += 1
         if self._add_done_loop and minibatch.done:
             # special operation for terminal states: add a self-loop
             trans_count[obs_next, :, obs_next] += 1
             rew_count[obs_next, :] += 1
     self.model.observe(trans_count, rew_sum, rew_square_sum, rew_count)
     return {
         "psrl/rew_mean": float(self.model.rew_mean.mean()),
         "psrl/rew_std": float(self.model.rew_std.mean()),
     }
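The statistics collected above are handed to `self.model.observe`, whose internals are not shown. As a rough sketch (an assumption, not the model's actual update), the counts reduce to the maximum-likelihood estimates below; a PSRL model would instead treat them as posterior parameters (e.g. a Dirichlet over transitions) and sample from them.

import numpy as np

def empirical_estimates(trans_count, rew_sum, rew_count):
    # empirical transition probabilities P(s'|s, a) and mean rewards R(s, a);
    # clipping the denominators to 1 avoids division by zero for unseen pairs
    trans_prob = trans_count / trans_count.sum(axis=-1, keepdims=True).clip(1, None)
    rew_mean = rew_sum / rew_count.clip(1, None)
    return trans_prob, rew_mean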
Example #27
def test_Batch():
    """
    batch.split()
    batch.append()
    len(batch)
    :return:
    """
    # data is a batch involving 4 transitions
    data = Batch(obs=np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1]]),
                 rew=np.array([0, 0, 0, 1]))
    index = [0, 1]  # pick the first 2 transitions
    print(data[0])
    print(len(data))
    print("--------------------")
    data.append(
        Batch(obs=np.array([[1, 0, 0], [1, 0, 1], [1, 1, 0]]),
              rew=np.array([-1, -1, -1])))
    print(data)
    print(len(data))
    print("--------------------")
    # the last batch might have a size less than 3
    for mini_batch in data.split(size=3, permute=False):
        print(mini_batch)
Example #28
    def learn(  # type: ignore
            self, batch: Batch, batch_size: int, repeat: int,
            **kwargs: Any) -> Dict[str, List[float]]:
        actor_losses, vf_losses, step_sizes, kls = [], [], [], []
        for _ in range(repeat):
            for minibatch in batch.split(batch_size, merge_last=True):
                # optimize actor
                # direction: calculate vanilla gradient
                dist = self(minibatch).dist  # TODO could come from batch
                ratio = (dist.log_prob(minibatch.act) -
                         minibatch.logp_old).exp().float()
                ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                actor_loss = -(ratio * minibatch.adv).mean()
                flat_grads = self._get_flat_grad(actor_loss,
                                                 self.actor,
                                                 retain_graph=True).detach()

                # direction: calculate natural gradient
                with torch.no_grad():
                    old_dist = self(minibatch).dist

                kl = kl_divergence(old_dist, dist).mean()
                # calculate first order gradient of kl with respect to theta
                flat_kl_grad = self._get_flat_grad(kl,
                                                   self.actor,
                                                   create_graph=True)
                search_direction = -self._conjugate_gradients(
                    flat_grads, flat_kl_grad, nsteps=10)

                # stepsize: calculate max stepsize constrained by kl bound
                step_size = torch.sqrt(
                    2 * self._delta / (search_direction * self._MVP(
                        search_direction, flat_kl_grad)).sum(0, keepdim=True))

                # stepsize: linesearch stepsize
                with torch.no_grad():
                    flat_params = torch.cat([
                        param.data.view(-1)
                        for param in self.actor.parameters()
                    ])
                    for i in range(self._max_backtracks):
                        new_flat_params = flat_params + step_size * search_direction
                        self._set_from_flat_params(self.actor, new_flat_params)
                        # check that kl is within the bound and the loss actually decreased
                        new_dist = self(minibatch).dist
                        new_dratio = (new_dist.log_prob(minibatch.act) -
                                      minibatch.logp_old).exp().float()
                        new_dratio = new_dratio.reshape(
                            new_dratio.size(0), -1).transpose(0, 1)
                        new_actor_loss = -(new_dratio * minibatch.adv).mean()
                        kl = kl_divergence(old_dist, new_dist).mean()

                        if kl < self._delta and new_actor_loss < actor_loss:
                            if i > 0:
                                warnings.warn(f"Backtracking to step {i}.")
                            break
                        elif i < self._max_backtracks - 1:
                            step_size = step_size * self._backtrack_coeff
                        else:
                            # restore the original actor parameters
                            self._set_from_flat_params(self.actor,
                                                       flat_params)
                            step_size = torch.tensor([0.0])
                            warnings.warn(
                                "Line search failed! It seems hyperparameters"
                                " are poor and need to be changed.")

                # optimize critic
                for _ in range(self._optim_critic_iters):
                    value = self.critic(minibatch.obs).flatten()
                    vf_loss = F.mse_loss(minibatch.returns, value)
                    self.optim.zero_grad()
                    vf_loss.backward()
                    self.optim.step()

                actor_losses.append(actor_loss.item())
                vf_losses.append(vf_loss.item())
                step_sizes.append(step_size.item())
                kls.append(kl.item())

        return {
            "loss/actor": actor_losses,
            "loss/vf": vf_losses,
            "step_size": step_sizes,
            "kl": kls,
        }
Example #29
def test_batch():
    assert list(Batch()) == []
    assert Batch().is_empty()
    assert not Batch(b={'c': {}}).is_empty()
    assert Batch(b={'c': {}}).is_empty(recurse=True)
    assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
    assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
    assert not Batch(d=1).is_empty()
    assert not Batch(a=np.float64(1.0)).is_empty()
    assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
    assert not Batch(a=[1, 2, 3]).is_empty()
    b = Batch({'a': [4, 4], 'b': [5, 5]}, c=[None, None])
    assert b.c.dtype == object
    b = Batch(d=[None], e=[starmap], f=Batch)
    assert b.d.dtype == b.e.dtype == object and b.f == Batch
    b = Batch()
    b.update()
    assert b.is_empty()
    b.update(c=[3, 5])
    assert np.allclose(b.c, [3, 5])
    # mimic the behavior of dict.update, where kwargs can overwrite keys
    b.update({'a': 2}, a=3)
    assert 'a' in b and b.a == 3
    assert b.pop('a') == 3
    assert 'a' not in b
    with pytest.raises(AssertionError):
        Batch({1: 2})
    assert Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]).a.dtype == object
    with pytest.raises(TypeError):
        Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])
    batch = Batch(a=[torch.ones(3), torch.ones(3)])
    assert torch.allclose(batch.a, torch.ones(2, 3))
    batch.cat_(batch)
    assert torch.allclose(batch.a, torch.ones(4, 3))
    Batch(a=[])
    batch = Batch(obs=[0], np=np.zeros([3, 4]))
    assert batch.obs == batch["obs"]
    batch.obs = [1]
    assert batch.obs == [1]
    batch.cat_(batch)
    assert np.allclose(batch.obs, [1, 1])
    assert batch.np.shape == (6, 4)
    assert np.allclose(batch[0].obs, batch[1].obs)
    batch.obs = np.arange(5)
    for i, b in enumerate(batch.split(1, shuffle=False)):
        if i != 5:
            assert b.obs == batch[i].obs
        else:
            with pytest.raises(AttributeError):
                batch[i].obs
            with pytest.raises(AttributeError):
                b.obs
    print(batch)
    batch = Batch(a=np.arange(10))
    with pytest.raises(AssertionError):
        list(batch.split(0))
    data = [
        (1, False, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]),
        (1, True, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]),
        (3, False, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]),
        (3, True, [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]),
        (5, False, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
        (5, True, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
        (7, False, [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]),
        (7, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (10, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (10, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (15, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (15, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (100, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (100, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
    ]
    for size, merge_last, result in data:
        bs = list(batch.split(size, shuffle=False, merge_last=merge_last))
        assert [bs[i].a.tolist() for i in range(len(bs))] == result
    batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])}
    batch_item = Batch({'a': [batch_dict]})[0]
    assert isinstance(batch_item.a.b, np.ndarray)
    assert batch_item.a.b == batch_dict['b']
    assert isinstance(batch_item.a.c, float)
    assert batch_item.a.c == batch_dict['c']
    assert isinstance(batch_item.a.d, torch.Tensor)
    assert batch_item.a.d == batch_dict['d']
    batch2 = Batch(a=[{
        'b': np.float64(1.0),
        'c': np.zeros(1),
        'd': Batch(e=np.array(3.0))}])
    assert len(batch2) == 1
    assert Batch().shape == []
    assert Batch(a=1).shape == []
    assert Batch(a=set((1, 2, 1))).shape == []
    assert batch2.shape[0] == 1
    assert 'a' in batch2 and all([i in batch2.a for i in 'bcd'])
    with pytest.raises(IndexError):
        batch2[-2]
    with pytest.raises(IndexError):
        batch2[1]
    assert batch2[0].shape == []
    with pytest.raises(IndexError):
        batch2[0][0]
    with pytest.raises(TypeError):
        len(batch2[0])
    assert isinstance(batch2[0].a.c, np.ndarray)
    assert isinstance(batch2[0].a.b, np.float64)
    assert isinstance(batch2[0].a.d.e, np.float64)
    batch2_from_list = Batch(list(batch2))
    batch2_from_comp = Batch([e for e in batch2])
    assert batch2_from_list.a.b == batch2.a.b
    assert batch2_from_list.a.c == batch2.a.c
    assert batch2_from_list.a.d.e == batch2.a.d.e
    assert batch2_from_comp.a.b == batch2.a.b
    assert batch2_from_comp.a.c == batch2.a.c
    assert batch2_from_comp.a.d.e == batch2.a.d.e
    for batch_slice in [batch2[slice(0, 1)], batch2[:1], batch2[0:]]:
        assert batch_slice.a.b == batch2.a.b
        assert batch_slice.a.c == batch2.a.c
        assert batch_slice.a.d.e == batch2.a.d.e
    batch2.a.d.f = {}
    batch2_sum = (batch2 + 1.0) * 2
    assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
    assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
    assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
    assert batch2_sum.a.d.f.is_empty()
    with pytest.raises(TypeError):
        batch2 += [1]
    batch3 = Batch(a={
        'c': np.zeros(1),
        'd': Batch(e=np.array([0.0]), f=np.array([3.0]))})
    batch3.a.d[0] = {'e': 4.0}
    assert batch3.a.d.e[0] == 4.0
    batch3.a.d[0] = Batch(f=5.0)
    assert batch3.a.d.f[0] == 5.0
    with pytest.raises(ValueError):
        batch3.a.d[0] = Batch(f=5.0, g=0.0)
    with pytest.raises(ValueError):
        batch3[0] = Batch(a={"c": 2, "e": 1})
    # auto convert
    batch4 = Batch(a=np.array(['a', 'b']))
    assert batch4.a.dtype == object  # auto convert to object
    batch4.update(a=np.array(['c', 'd']))
    assert list(batch4.a) == ['c', 'd']
    assert batch4.a.dtype == object  # auto convert to object
    batch5 = Batch(a=np.array([{'index': 0}]))
    assert isinstance(batch5.a, Batch)
    assert np.allclose(batch5.a.index, [0])
    batch5.b = np.array([{'index': 1}])
    assert isinstance(batch5.b, Batch)
    assert np.allclose(batch5.b.index, [1])

    # None is a valid object and can be stored in Batch
    a = Batch.stack([Batch(a=None), Batch(b=None)])
    assert a.a[0] is None and a.a[1] is None
    assert a.b[0] is None and a.b[1] is None

    # nx.Graph corner case
    assert Batch(a=np.array([nx.Graph(), nx.Graph()], dtype=object)).a.dtype == object
    g1 = nx.Graph()
    g1.add_nodes_from(list(range(10)))
    g2 = nx.Graph()
    g2.add_nodes_from(list(range(20)))
    assert Batch(a=np.array([g1, g2])).a.dtype == object
Example #30
 def learn(self, batch: Batch, batch_size: int, repeat: int,
           **kwargs) -> Dict[str, List[float]]:
     self._batch = batch_size
     losses, clip_losses, vf_losses, ent_losses = [], [], [], []
     v = []
     old_log_prob = []
     with torch.no_grad():
         for b in batch.split(batch_size, shuffle=False):
             v.append(self.critic(b.obs))
             old_log_prob.append(
                 self(b).dist.log_prob(
                     torch.tensor(b.act, device=v[0].device)))
     batch.v = torch.cat(v, dim=0)  # old value
     dev = batch.v.device
     batch.act = torch.tensor(batch.act, dtype=torch.float, device=dev)
     batch.logp_old = torch.cat(old_log_prob, dim=0)
     batch.returns = torch.tensor(batch.returns,
                                  dtype=torch.float,
                                  device=dev).reshape(batch.v.shape)
     if self._rew_norm:
         mean, std = batch.returns.mean(), batch.returns.std()
         if std > self.__eps:
             batch.returns = (batch.returns - mean) / std
     batch.adv = batch.returns - batch.v
     if self._rew_norm:
         mean, std = batch.adv.mean(), batch.adv.std()
         if std > self.__eps:
             batch.adv = (batch.adv - mean) / std
     for _ in range(repeat):
         for b in batch.split(batch_size):
             dist = self(b).dist
             value = self.critic(b.obs)
             ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
             surr1 = ratio * b.adv
             surr2 = ratio.clamp(1. - self._eps_clip,
                                 1. + self._eps_clip) * b.adv
             if self._dual_clip:
                 clip_loss = -torch.max(torch.min(surr1, surr2),
                                        self._dual_clip * b.adv).mean()
             else:
                 clip_loss = -torch.min(surr1, surr2).mean()
             clip_losses.append(clip_loss.item())
             if self._value_clip:
                 v_clip = b.v + (value - b.v).clamp(-self._eps_clip,
                                                    self._eps_clip)
                 vf1 = (b.returns - value).pow(2)
                 vf2 = (b.returns - v_clip).pow(2)
                 vf_loss = .5 * torch.max(vf1, vf2).mean()
             else:
                 vf_loss = .5 * (b.returns - value).pow(2).mean()
             vf_losses.append(vf_loss.item())
             e_loss = dist.entropy().mean()
             ent_losses.append(e_loss.item())
             loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss
             losses.append(loss.item())
             self.optim.zero_grad()
             loss.backward()
             nn.utils.clip_grad_norm_(
                 list(self.actor.parameters()) +
                 list(self.critic.parameters()), self._max_grad_norm)
             self.optim.step()
     return {
         'loss': losses,
         'loss/clip': clip_losses,
         'loss/vf': vf_losses,
         'loss/ent': ent_losses,
     }