    def train(self, epoch: int = -1, writer=None) -> str:
        self.put_model_on_device()
        batch = self.data_loader.get_dataset()
        assert len(batch) != 0

        values = self._net.critic(inputs=batch.observations, train=False).squeeze().detach()
        phi_weights = self._calculate_phi(batch, values).to(self._device).squeeze(-1).detach()

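        # With GAE, the critic regresses towards the lambda-returns (values + advantages);
        # otherwise it regresses towards the Monte-Carlo reward-to-go.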
        critic_targets = get_reward_to_go(batch).to(self._device) if self._config.phi_key != 'gae' else \
            (values + phi_weights).detach()
        critic_loss_distribution = self._train_critic_clipped(batch, critic_targets, values)
        actor_loss_distribution = self._train_actor_ppo(batch, phi_weights, writer)

        if writer is not None:
            writer.write_distribution(critic_loss_distribution, "critic_loss")
            writer.write_distribution(Distribution(phi_weights.detach()), "phi_weights")
            writer.write_distribution(Distribution(critic_targets.detach()), "critic_targets")

        if self._config.scheduler_config is not None:
            self._actor_scheduler.step()
            self._critic_scheduler.step()
        self._net.global_step += 1
        self.put_model_back_to_original_device()
        return f" training policy loss {actor_loss_distribution.mean: 0.3e} [{actor_loss_distribution.std: 0.2e}], " \
               f"critic loss {critic_loss_distribution.mean: 0.3e} [{critic_loss_distribution.std: 0.3e}]"

    def test_generalized_advantage_estimate(self):
        # with gae_lambda == 1 and zero values --> same as reward-to-go
        rtg_returns = get_generalized_advantage_estimate(
            batch_rewards=self.batch.rewards,
            batch_done=self.batch.done,
            batch_values=[torch.as_tensor(0.)] * len(self.batch),
            discount=1,
            gae_lambda=1)
        for r_e, r_t in zip(rtg_returns, get_reward_to_go(self.batch)):
            self.assertEqual(r_e, r_t)

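        # with gae_lambda == 0 and zero values --> one-step estimate, i.e. just the immediate reward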
        one_step_returns = get_generalized_advantage_estimate(
            batch_rewards=self.batch.rewards,
            batch_done=self.batch.done,
            batch_values=[torch.as_tensor(0.)] * len(self.batch),
            discount=1,
            gae_lambda=0)
        targets = [
            self.step_reward if d == 0 else self.end_reward
            for d in self.batch.done
        ]
        for r_e, r_t in zip(one_step_returns, targets):
            self.assertEqual(r_e, r_t)

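        # with intermediate discount and gae_lambda --> estimates lie between the two extremes above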
        gae_returns = get_generalized_advantage_estimate(
            batch_rewards=self.batch.rewards,
            batch_done=self.batch.done,
            batch_values=[torch.as_tensor(0.)] * len(self.batch),
            discount=0.99,
            gae_lambda=0.99)
        for t in range(len(self.batch)):
            self.assertGreaterEqual(gae_returns[t], one_step_returns[t])
            self.assertLessEqual(gae_returns[t], rtg_returns[t])

    def test_get_reward_to_go(self):
        returns = get_reward_to_go(self.batch)
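        # expected: within each episode the reward-to-go counts down from
        # end_reward + (duration - 1) * step_reward to end_reward at the final step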
        targets = reversed([
            self.end_reward + t * self.step_reward
            for duration in reversed(self.durations) for t in range(duration)
        ])

        for r_e, r_t in zip(returns, targets):
            self.assertEqual(r_e, r_t)
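
# A minimal sketch of the reward-to-go helper exercised by the test above,
# assuming `rewards` and `done` are flat per-step sequences in which `done`
# marks episode ends with 1. The repository's get_reward_to_go may differ;
# this version accumulates undiscounted returns backwards and resets the
# running sum at episode boundaries, which is the behaviour the test asserts.
import torch


def reward_to_go_sketch(rewards, done) -> torch.Tensor:
    returns = [torch.as_tensor(0.)] * len(rewards)
    running = torch.as_tensor(0.)
    for t in reversed(range(len(rewards))):
        # zero out the accumulated future return when the episode ends at step t
        running = torch.as_tensor(rewards[t]) + (1 - done[t]) * running
        returns[t] = running
    return torch.stack(returns)
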
    def _calculate_phi(self,
                       batch: Dataset,
                       values: torch.Tensor = None) -> torch.Tensor:
        """Return the per-step policy-gradient weights for the configured phi_key."""
        if self._config.phi_key == "return":
            return get_returns(batch)
        elif self._config.phi_key == "reward-to-go":
            return get_reward_to_go(batch)
        elif self._config.phi_key == "gae":
            return get_generalized_advantage_estimate(
                batch_rewards=batch.rewards,
                batch_done=batch.done,
                batch_values=values,
                discount=0.99 if self._config.discount == "default"
                else self._config.discount,
                gae_lambda=0.95 if self._config.gae_lambda == "default"
                else self._config.gae_lambda,
            )
        elif self._config.phi_key == "value-baseline":
            values = self._net.critic(batch.observations,
                                      train=False).detach().squeeze()
            returns = get_reward_to_go(batch)
            return returns - values
        else:
            raise NotImplementedError(
                f"Unknown phi_key: {self._config.phi_key}")
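
# Illustrative sketch of the generalized advantage estimate assumed by
# _calculate_phi and the test above; the repository's
# get_generalized_advantage_estimate may differ in details (here the value
# after the last step of the batch is assumed to be zero). The recursion is
#   delta_t = r_t + discount * V_{t+1} * (1 - done_t) - V_t
#   A_t     = delta_t + discount * gae_lambda * (1 - done_t) * A_{t+1}
# so gae_lambda = 1 with zero values reduces to reward-to-go and
# gae_lambda = 0 reduces to the one-step estimate, matching the test.
def gae_sketch(batch_rewards, batch_done, batch_values,
               discount: float, gae_lambda: float) -> torch.Tensor:
    advantages = [torch.as_tensor(0.)] * len(batch_rewards)
    running = torch.as_tensor(0.)
    for t in reversed(range(len(batch_rewards))):
        not_done = 1 - batch_done[t]
        next_value = batch_values[t + 1] if t + 1 < len(batch_values) \
            else torch.as_tensor(0.)
        delta = batch_rewards[t] + discount * next_value * not_done \
            - batch_values[t]
        running = delta + discount * gae_lambda * not_done * running
        advantages[t] = running
    return torch.stack(advantages)
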
    def train(self, epoch: int = -1, writer=None) -> str:
        self.put_model_on_device()
        batch = self.data_loader.get_dataset()
        assert len(batch) != 0

        values = self._net.critic(inputs=batch.observations,
                                  train=False).squeeze().detach()
        phi_weights = self._calculate_phi(batch, values).to(self._device)
        policy_loss = self._train_actor(batch, phi_weights)
        critic_loss = Distribution(
            self._train_critic(batch,
                               get_reward_to_go(batch).to(self._device)))

        if writer is not None:
            writer.set_step(self._net.global_step)
            writer.write_scalar(policy_loss.data, "policy_loss")
            writer.write_distribution(critic_loss, "critic_loss")

        self._net.global_step += 1
        self.put_model_back_to_original_device()
        if self._config.scheduler_config is not None:
            self._actor_scheduler.step()
            self._critic_scheduler.step()
        return f" training policy loss {policy_loss.data: 0.3e}, critic loss {critic_loss.mean: 0.3e}"
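
# The Distribution wrapper used for logging above is assumed to expose .data
# together with .mean and .std summary statistics over a tensor of per-sample
# losses; the class below is a minimal sketch of that assumption, not the
# repository's actual implementation, which may carry more fields.
class DistributionSketch:
    def __init__(self, data):
        self.data = data
        self.mean = float(data.mean())
        self.std = float(data.std())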