def test_significance_map(self):
    """Check `significance_map` for a Box observation and MultiDiscrete action.

    With `BoxSpaceSerializer.precision` bound to 3, the expected map repeats
    an 8-symbol per-step pattern and is truncated to the requested
    representation length of 20.
    """
    gin.bind_parameter('BoxSpaceSerializer.precision', 3)

    obs_serializer = space_serializer.create(
        gym.spaces.Box(low=0, high=1, shape=(2,)), vocab_size=2
    )
    act_serializer = space_serializer.create(
        gym.spaces.MultiDiscrete(nvec=[2, 2]), vocab_size=2
    )
    actual = serialization_utils.significance_map(
        observation_serializer=obs_serializer,
        action_serializer=act_serializer,
        representation_length=20,
    )

    # One (observation, action) step: six observation symbols with
    # significance 0, 1, 2 repeated twice, then two action symbols at 0.
    step_pattern = [0, 1, 2, 0, 1, 2] + [0, 0]
    # obs1, act1, obs2, act2, obs3 cut after 4th symbol.
    expected = (step_pattern * 3)[:20]

    np.testing.assert_array_equal(actual, expected)
def trajectory_to_training_examples(self, trajectory):
    """Convert one trajectory into a single serialized training example.

    The trajectory's observations and actions are serialized into one
    sequence whose length is taken from `self.model_input_shape`. The
    per-symbol weights are the product of the significance-decay weights,
    the observation mask, and the mask returned by the serializer.

    Args:
        trajectory: object exposing `num_time_steps`, `observations_np` and
            `actions_np`.

    Returns:
        A one-element list holding an `(inputs, targets, weights)` tuple;
        inputs and targets are the same serialized array.
    """
    (length,) = self.model_input_shape
    # All-ones mask of shape (1, num_time_steps - 1).
    step_mask = np.ones((1, trajectory.num_time_steps - 1))

    # Serialization works on batches, so we add a singleton batch dimension.
    batched_observations = trajectory.observations_np[None, ...]
    batched_actions = trajectory.actions_np[None, ...]
    serialized, serialized_mask = (
        serialization_utils.serialize_observations_and_actions(
            batched_observations,
            batched_actions,
            step_mask,
            self._obs_serializer,
            self._action_serializer,
            length,
        )
    )
    # Drop the singleton batch dimension and cast to the model's input dtype.
    tokens = serialized[0, ...].astype(self.model_input_dtype)

    # Weight each symbol by decay ** significance, with a leading batch axis.
    decay = self._significance_decay
    significance = serialization_utils.significance_map(
        self._obs_serializer, self._action_serializer, length
    )
    decay_weights = decay ** significance[None, ...]

    observation_mask = serialization_utils.observation_mask(
        self._obs_serializer, self._action_serializer, length
    )

    combined = decay_weights * observation_mask * serialized_mask
    token_weights = combined[0, ...]

    # (inputs, targets, weights)
    example = (tokens, tokens, token_weights)
    return [example]