def test_inits_serialization(self):
    """Smoke-tests that ppo.init_serialization returns usable kwargs.

    The kwargs returned by ``ppo.init_serialization`` are forwarded unchanged
    to ``serialization_utils.observation_mask``. The test passes as long as
    that call does not raise.
    """
    obs_space = gym.spaces.Box(shape=(2, 3), low=0, high=1)
    act_space = gym.spaces.Discrete(n=3)
    kwargs = ppo.init_serialization(
        vocab_size=4,
        observation_space=obs_space,
        action_space=act_space,
        n_timesteps=6,
    )
    # A serialization_utils function must accept the returned kwargs as-is.
    serialization_utils.observation_mask(**kwargs)
def _serialize_trajectories(self, observations, actions, reward_mask):
    """Serializes observations and actions into a masked/unmasked pair.

    Args:
      observations: forwarded as ``observations=`` to
        ``serialization_utils.serialize_observations_and_actions``.
      actions: forwarded as ``actions=`` to the same function.
      reward_mask: accepted by the signature but not used in this method.

    Returns:
      A tuple ``(masked, serialized)``. ``serialized`` is the representation
      returned by ``serialize_observations_and_actions``. ``masked`` is that
      representation multiplied by ``serialization_utils.observation_mask``.
      Both calls receive ``self._serialization_kwargs``.
    """
    kwargs = self._serialization_kwargs
    serialized = serialization_utils.serialize_observations_and_actions(
        observations=observations, actions=actions, **kwargs)
    # Mask out actions in the representation - otherwise we would sample an
    # action based on itself.
    masked = serialized * serialization_utils.observation_mask(**kwargs)
    return (masked, serialized)
def test_masks_actions(self):
    """Checks that the observation mask hides actions and the action mask does not.

    The batch has two rows with identical observations and different actions.
    Under the observation mask the two serialized rows must be equal. Under
    the action mask they must differ.
    """
    kwargs = self._serialization_utils_kwargs
    # Observations are the same in both rows; only the action differs.
    same_obs = onp.array([[0, 1], [0, 1]])
    diff_actions = onp.array([[0], [1]])
    reprs = serialization_utils.serialize_observations_and_actions(
        observations=same_obs, actions=diff_actions, **kwargs)
    obs_mask = serialization_utils.observation_mask(**kwargs)
    act_mask = serialization_utils.action_mask(**kwargs)
    first, second = reprs[0], reprs[1]
    onp.testing.assert_array_equal(first * obs_mask, second * obs_mask)
    self.assertFalse(onp.array_equal(first * act_mask, second * act_mask))
def test_observation_and_action_masks_are_valid_and_complementary(self):
    """Checks the shape and range of both masks, and that they sum to ones.

    Each mask must have shape ``(self._repr_length,)``, a minimum of 0 and a
    maximum of 1. Their element-wise sum must be all ones.
    """
    kwargs = self._serialization_utils_kwargs
    expected_shape = (self._repr_length, )
    masks = []
    for make_mask in (serialization_utils.observation_mask,
                      serialization_utils.action_mask):
        mask = make_mask(**kwargs)
        self.assertEqual(mask.shape, expected_shape)
        self.assertEqual(onp.min(mask), 0)
        self.assertEqual(onp.max(mask), 1)
        masks.append(mask)
    obs_mask, act_mask = masks
    onp.testing.assert_array_equal(obs_mask + act_mask,
                                   onp.ones(self._repr_length))
def trajectory_to_training_examples(self, trajectory):
    """Turns one trajectory into a single (inputs, targets, weights) example.

    Args:
      trajectory: object exposing ``num_time_steps``, ``observations_np`` and
        ``actions_np``.

    Returns:
      A one-element list ``[(reprs, reprs, weights)]``.
      - ``reprs`` is the serialized trajectory cast to
        ``self.model_input_dtype``.
      - ``weights`` is ``self._significance_decay ** significance_map``
        multiplied by the observation mask and by the mask returned from the
        serializer.
    """
    (repr_length, ) = self.model_input_shape
    obs_ser = self._obs_serializer
    act_ser = self._action_serializer
    # num_time_steps - 1 entries; presumably one per action - TODO confirm.
    seq_mask = np.ones((1, trajectory.num_time_steps - 1))
    # Serialization works on batches, so add a singleton batch dimension.
    batched_obs = trajectory.observations_np[None, ...]
    batched_actions = trajectory.actions_np[None, ...]
    serialized, repr_mask = (
        serialization_utils.serialize_observations_and_actions(
            batched_obs,
            batched_actions,
            seq_mask,
            obs_ser,
            act_ser,
            repr_length,
        ))
    reprs = serialized[0, ...].astype(self.model_input_dtype)
    decay = self._significance_decay
    significance = serialization_utils.significance_map(
        obs_ser, act_ser, repr_length)[None, ...]
    sig_weights = decay**significance
    obs_mask = serialization_utils.observation_mask(
        obs_ser, act_ser, repr_length)
    # Drop the singleton batch dimension from the combined weights.
    weights = (sig_weights * obs_mask * repr_mask)[0, ...]
    # (inputs, targets, weights)
    return [(reprs, reprs, weights)]