Exemplo n.º 1
0
 def test_significance_map(self):
     """Checks the significance map for a Box(2) / MultiDiscrete([2, 2]) pair.

     Observations are serialized with `BoxSpaceSerializer.precision` bound to
     3 and actions with a MultiDiscrete serializer, both with vocab size 2.
     The map is requested for a representation of length 20.
     """
     gin.bind_parameter('BoxSpaceSerializer.precision', 3)
     obs_space = gym.spaces.Box(low=0, high=1, shape=(2, ))
     obs_serializer = space_serializer.create(obs_space, vocab_size=2)
     act_space = gym.spaces.MultiDiscrete(nvec=[2, 2])
     act_serializer = space_serializer.create(act_space, vocab_size=2)

     actual = serialization_utils.significance_map(
         observation_serializer=obs_serializer,
         action_serializer=act_serializer,
         representation_length=20,
     )

     # Expected layout: every observation contributes [0, 1, 2, 0, 1, 2] and
     # every action [0, 0]. Two full (obs, act) pairs take 16 symbols, so the
     # third observation is cut after its 4th symbol to reach length 20.
     obs_then_act = [0, 1, 2, 0, 1, 2, 0, 0]
     expected = obs_then_act * 2 + [0, 1, 2, 0]
     np.testing.assert_array_equal(actual, expected)
 def trajectory_to_training_examples(self, trajectory):
     """Turns one trajectory into a list holding a single training example.

     The trajectory's observations and actions are serialized into one symbol
     sequence whose length comes from the 1-tuple `self.model_input_shape`.
     That sequence serves as both the input and the target. Per-symbol
     weights are `self._significance_decay ** significance_map`. They are
     multiplied by the observation mask and by the mask that the serializer
     returns.

     Args:
       trajectory: object providing `num_time_steps`, `observations_np` and
         `actions_np`.

     Returns:
       A one-element list `[(inputs, targets, weights)]`. `inputs` and
       `targets` are the same serialized array cast to
       `self.model_input_dtype`.
     """
     (seq_len, ) = self.model_input_shape
     step_mask = np.ones((1, trajectory.num_time_steps - 1))
     # The serializer operates on batches, so wrap everything in a batch of one.
     obs_batch = trajectory.observations_np[np.newaxis, ...]
     act_batch = trajectory.actions_np[np.newaxis, ...]
     serialized, serialized_mask = (
         serialization_utils.serialize_observations_and_actions(
             obs_batch,
             act_batch,
             step_mask,
             self._obs_serializer,
             self._action_serializer,
             seq_len,
         ))
     # Drop the singleton batch dimension added above.
     inputs = serialized[0, ...].astype(self.model_input_dtype)

     exponents = serialization_utils.significance_map(
         self._obs_serializer, self._action_serializer, seq_len)
     decayed = self._significance_decay**exponents[np.newaxis, ...]
     observation_weights = serialization_utils.observation_mask(
         self._obs_serializer, self._action_serializer, seq_len)
     # Combine in batched form, then drop the batch dimension again.
     weights = (decayed * observation_weights * serialized_mask)[0, ...]
     return [(inputs, inputs, weights)]