Exemplo n.º 1
0
 def inverse(self, observation):
     """Apply the normalizer's inverse to the state of ``observation``.

     Implements ``AbstractTransform.inverse``. The observation is modified
     in place: ``state`` is replaced by ``self._normalizer.inverse(state)``,
     and ``state_scale_tril`` is passed to ``rescale`` together with the
     diagonal matrix ``diag(sqrt(variance))`` built from the normalizer's
     variance.

     Args:
         observation: object exposing ``state`` and ``state_scale_tril``
             attributes; both are overwritten.

     Returns:
         The same ``observation`` object that was passed in.
     """
     # Build the scale matrix from the current variance before the normalizer
     # is invoked, preserving the original evaluation order.
     std_dev = torch.sqrt(self._normalizer.variance)
     std_matrix = torch.diag_embed(std_dev)

     restored_state = self._normalizer.inverse(observation.state)
     observation.state = restored_state

     # NOTE(review): exact semantics of ``rescale`` live elsewhere in the file;
     # presumably it applies the diagonal scaling to the Cholesky-like factor.
     old_tril = observation.state_scale_tril
     observation.state_scale_tril = rescale(old_tril, std_matrix)
     return observation
Exemplo n.º 2
0
 def forward(self, observation):
     """Normalize the reward of ``observation`` in place.

     Implements ``AbstractTransform.__call__``. ``reward`` is replaced by
     ``self._normalizer(reward)``, and ``reward_scale_tril`` is passed to
     ``rescale`` together with the diagonal matrix ``diag(1 / sqrt(variance))``
     built from the normalizer's variance.

     Args:
         observation: object exposing ``reward`` and ``reward_scale_tril``
             attributes; both are overwritten.

     Returns:
         The same ``observation`` object that was passed in.
     """
     # Capture the scale from the variance *before* calling the normalizer.
     # If calling it updates running statistics (not visible here), this
     # ordering determines which variance is used.
     # NOTE(review): a zero variance entry yields ``inf`` here unless the
     # normalizer guards against it -- confirm.
     inv_std = 1 / torch.sqrt(self._normalizer.variance)
     inv_std_matrix = torch.diag_embed(inv_std)

     observation.reward = self._normalizer(observation.reward)

     observation.reward_scale_tril = rescale(
         observation.reward_scale_tril, inv_std_matrix
     )
     return observation