Example #1
 def test_serializes_observations_and_actions(self):
     (reprs, mask) = serialization_utils.serialize_observations_and_actions(
         observations=np.array([[0, 1]]),
         actions=np.array([[0]]),
         mask=np.array([[1]]),
         **self._serialization_utils_kwargs)
     self.assertEqual(reprs.shape, (1, self._repr_length))
     self.assertEqual(mask.shape, (1, self._repr_length))
     self.assertGreater(np.sum(mask), 0)
     self.assertEqual(np.max(mask), 1)
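The fixture names used above (self._serialization_utils_kwargs, self._repr_length) come from the test's setUp, which is not shown. Below is a minimal sketch of what such a fixture could look like, assuming a trax-style module layout with a space_serializer.create(space, vocab_size) helper and keyword names matching the positional arguments visible in Example #6 further down; the class name, the vocab size, and the representation length are illustrative assumptions, not the original values.
 import gym
 from absl.testing import absltest
 from trax.rl import space_serializer

 class SerializationUtilsTest(absltest.TestCase):

     def setUp(self):
         super().setUp()
         # Hypothetical fixture: one discrete serializer reused for both
         # observations and actions; names and sizes are assumptions.
         serializer = space_serializer.create(gym.spaces.Discrete(2), vocab_size=2)
         self._repr_length = 100
         self._serialization_utils_kwargs = {
             'observation_serializer': serializer,
             'action_serializer': serializer,
             'representation_length': self._repr_length,
         }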
Example #2
 def _serialize_trajectories(self, observations, actions, reward_mask):
     reprs = serialization_utils.serialize_observations_and_actions(
         observations=observations,
         actions=actions,
         **self._serialization_kwargs)
     # Mask out actions in the representation - otherwise we sample an action
     # based on itself.
     observations = reprs * serialization_utils.observation_mask(
         **self._serialization_kwargs)
     actions = reprs
     return (observations, actions)
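To make the "mask out actions" trick concrete, here is a self-contained numpy toy that is independent of serialization_utils and uses a made-up symbol layout: the observation mask is a 0/1 vector over representation positions, so multiplying by it zeroes the action symbols while leaving the observation symbols in place, and the unmasked representation stays the target from which actions are sampled.
 import numpy as np

 # Made-up layout: two observation symbols followed by one action symbol,
 # repeated for two timesteps (representation length 6).
 reprs = np.array([[3, 7, 1, 4, 2, 0]])
 obs_mask = np.array([1, 1, 0, 1, 1, 0])  # 1 on observation positions
 act_mask = 1 - obs_mask                  # 1 on action positions

 observations_input = reprs * obs_mask    # [[3, 7, 0, 4, 2, 0]]
 actions_target = reprs                   # the full sequence stays the target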
Example #3
    def test_masks_actions(self):
        reprs = serialization_utils.serialize_observations_and_actions(
            # Observations are the same, actions are different.
            observations=onp.array([[0, 1], [0, 1]]),
            actions=onp.array([[0], [1]]),
            **self._serialization_utils_kwargs)
        obs_mask = serialization_utils.observation_mask(
            **self._serialization_utils_kwargs)
        act_mask = serialization_utils.action_mask(
            **self._serialization_utils_kwargs)

        onp.testing.assert_array_equal(reprs[0] * obs_mask,
                                       reprs[1] * obs_mask)
        self.assertFalse(
            onp.array_equal(reprs[0] * act_mask, reprs[1] * act_mask))
Example #4
def _prepare_policy_input(observations, vocab_size, observation_space,
                          action_space):
    """Prepares policy input based on a sequence of observations."""
    if vocab_size is not None:
        (batch_size, n_timesteps) = observations.shape[:2]
        serialization_kwargs = init_serialization(vocab_size,
                                                  observation_space,
                                                  action_space, n_timesteps)
        actions = np.zeros(
            (batch_size, n_timesteps - 1) + action_space.shape,
            dtype=action_space.dtype,
        )
        policy_input = serialization_utils.serialize_observations_and_actions(
            observations=observations, actions=actions, **serialization_kwargs)
        return policy_input
    else:
        return observations
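init_serialization is defined elsewhere in the calling code. Given how its result is splatted into serialize_observations_and_actions above, a plausible sketch is a dict holding two serializers plus a total representation length; the helper space_serializer.create, the representation_length property, and the keyword names below are assumptions inferred from the other examples, not the original implementation.
 from trax.rl import space_serializer

 def init_serialization(vocab_size, observation_space, action_space,
                        n_timesteps):
     """Hypothetical sketch: builds kwargs for the serialization utilities."""
     obs_serializer = space_serializer.create(observation_space, vocab_size)
     act_serializer = space_serializer.create(action_space, vocab_size)
     repr_length = n_timesteps * (
         obs_serializer.representation_length +
         act_serializer.representation_length)
     return {
         'observation_serializer': obs_serializer,
         'action_serializer': act_serializer,
         'representation_length': repr_length,
     }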
Example #5
 def test_masks_length(self):
     (reprs, mask) = serialization_utils.serialize_observations_and_actions(
         observations=np.array([[0, 1, 0], [0, 1, 0], [0, 1, 1]]),
         actions=np.array([[0, 0], [0, 1], [0, 0]]),
         mask=np.array([[1, 0], [1, 1], [1, 1]]),
         **self._serialization_utils_kwargs)
     # Trajectories 1 and 2 are longer than trajectory 0.
     self.assertGreater(np.sum(mask[1]), np.sum(mask[0]))
     self.assertGreater(np.sum(mask[2]), np.sum(mask[0]))
     # Trajectory 0 is a common prefix of trajectories 1 and 2, which differ
     # from each other.
     np.testing.assert_array_equal(reprs[0] * mask[0], reprs[1] * mask[0])
     np.testing.assert_array_equal(reprs[0] * mask[0], reprs[2] * mask[0])
     self.assertFalse(np.array_equal(reprs[1] * mask[1],
                                     reprs[2] * mask[2]))
     # Trajectories should be padded with 0s.
     np.testing.assert_array_equal(reprs * (1 - mask),
                                   np.zeros((3, self._repr_length)))
Example #6
 def trajectory_to_training_examples(self, trajectory):
     (repr_length, ) = self.model_input_shape
     seq_mask = np.ones((1, trajectory.num_time_steps - 1))
     (
         reprs, repr_mask
     ) = serialization_utils.serialize_observations_and_actions(
         # Serialization works on batches, so we add a singleton batch dimension.
         trajectory.observations_np[None, ...],
         trajectory.actions_np[None, ...],
         seq_mask,
         self._obs_serializer,
         self._action_serializer,
         repr_length,
     )
     reprs = reprs[0, ...].astype(self.model_input_dtype)
     sig_weights = (
         self._significance_decay**serialization_utils.significance_map(
             self._obs_serializer, self._action_serializer,
             repr_length)[None, ...])
     obs_mask = serialization_utils.observation_mask(
         self._obs_serializer, self._action_serializer, repr_length)
     weights = (sig_weights * obs_mask * repr_mask)[0, ...]
     # (inputs, targets, weights)
     return [(reprs, reprs, weights)]
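A hedged follow-up showing how the (inputs, targets, weights) triples returned above might be batched for a language-model-style trainer; the stacking helper below is an illustration, not part of the original code.
 import numpy as np

 def batch_training_examples(example_lists):
     """Stacks (inputs, targets, weights) triples from several trajectories."""
     triples = [t for examples in example_lists for t in examples]
     inputs, targets, weights = (np.stack(xs) for xs in zip(*triples))
     return inputs, targets, weights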
Example #7
 def test_serializes_observations_and_actions(self):
     reprs = serialization_utils.serialize_observations_and_actions(
         observations=onp.array([[0, 1]]),
         actions=onp.array([[0]]),
         **self._serialization_utils_kwargs)
     self.assertEqual(reprs.shape, (1, self._repr_length))