예제 #1
0
 def policy_inputs(self, trajectory, values):
     """Create inputs to policy model from a TrajectoryNp and values.

     Args:
       trajectory: TrajectoryNp whose `rewards`, `returns`, `observations` and
         `actions` are read here; presumably arrays with a leading
         [batch_size, length] layout — TODO confirm against the caller.
       values: value estimates forwarded to
         `computation_utils.calculate_advantage`.

     Returns:
       Tuple `(observations, actions, advantage)`. Observations and actions are
       truncated along axis 1 so that their first two dimensions equal
       `advantage.shape`.
     """
     # How much TD to use is determined by the added policy slice length,
     # as the policy batches need to be this much longer to calculate TD.
     td = self._added_policy_slice_length
     advantage = computation_utils.calculate_advantage(
         trajectory.rewards, trajectory.returns, values, self._task.gamma,
         td)
     # With td > 0 the advantage is td steps shorter than the trajectory
     # (e.g. length 3 with td=1 yields 2 advantage steps). Drop the trailing
     # td steps of observations and actions so all three stay aligned.
     obs = trajectory.observations
     obs = obs[:, :-td] if td > 0 else obs
     act = trajectory.actions
     act = act[:, :-td] if td > 0 else act
     assert len(advantage.shape) == 2  # [batch_size, length]
     assert act.shape[0:2] == advantage.shape
     assert obs.shape[0:2] == advantage.shape
     return (obs, act, advantage)
예제 #2
0
 def test_calculate_advantage(self):
     """Check calculate_advantage with td=0 and td=1 on a 3-step trajectory."""
     gamma = 1
     rewards = np.array([[1, 1, 1]], dtype=np.float32)
     returns = np.array([[3, 2, 1]], dtype=np.float32)
     values = np.array([[2, 2, 2]], dtype=np.float32)

     # td=0: full length; expected entries equal returns - values.
     no_td = computation_utils.calculate_advantage(
         rewards, returns, values, gamma, 0)
     self.assertEqual(no_td.shape, (1, 3))
     for step, want in enumerate([1, 0, -1]):
         self.assertEqual(no_td[0, step], want)

     # td=1: one step shorter; expected entries are consistent with
     # rewards[t] + gamma * values[t + 1] - values[t] = 1 + 2 - 2.
     one_step = computation_utils.calculate_advantage(
         rewards, returns, values, gamma, 1)
     self.assertEqual(one_step.shape, (1, 2))
     for step, want in enumerate([1, 1]):
         self.assertEqual(one_step[0, step], want)
예제 #3
0
 def policy_inputs(self, trajectory, values):
     """Build AWR policy-training inputs from a TrajectoryNp and values.

     The advantage is turned into per-step weights exp(advantage / beta),
     capped from above at `self._w_max`.

     Returns:
       Tuple `(observations, actions, weights)`. Observations and actions are
       trimmed along axis 1 so that their first two dimensions equal
       `weights.shape`.
     """
     # The added policy slice length doubles as the TD horizon: policy batches
     # are made that much longer precisely so TD targets can be computed.
     n_td = self._added_policy_slice_length
     advantage = computation_utils.calculate_advantage(
         trajectory.rewards, trajectory.returns, values, self._task.gamma,
         n_td)
     weights = np.exp(advantage / self._beta)
     weights = np.minimum(weights, self._w_max)
     # With a positive TD horizon the weights lack the final n_td steps, so
     # cut those steps from observations and actions to keep them aligned.
     if n_td > 0:
         observations = trajectory.observations[:, :-n_td]
         actions = trajectory.actions[:, :-n_td]
     else:
         observations = trajectory.observations
         actions = trajectory.actions
     assert len(weights.shape) == 2  # [batch_size, length]
     assert actions.shape[0:2] == weights.shape
     assert observations.shape[0:2] == weights.shape
     return (observations, actions, weights)