예제 #1
0
    def test_masks_actions(self):
        """The action mask singles out the part of the reprs that differs.

        Two trajectories are serialized that share an identical observation
        sequence but carry different actions. After masking:
          * the observation-masked representations must be identical;
          * the action-masked representations must not be identical.
        """
        kwargs = self._serialization_utils_kwargs
        # Observations are the same, actions are different.
        shared_observations = onp.array([[0, 1], [0, 1]])
        differing_actions = onp.array([[0], [1]])
        serialized = serialization_utils.serialize_observations_and_actions(
            observations=shared_observations,
            actions=differing_actions,
            **kwargs)
        observation_part = serialization_utils.observation_mask(**kwargs)
        action_part = serialization_utils.action_mask(**kwargs)

        first, second = serialized[0], serialized[1]
        onp.testing.assert_array_equal(
            first * observation_part, second * observation_part)
        # Only the actions differ, so masking them in must expose a mismatch.
        actions_identical = onp.array_equal(
            first * action_part, second * action_part)
        self.assertFalse(actions_identical)
예제 #2
0
    def test_observation_and_action_masks_are_valid_and_complementary(self):
        """Observation and action masks are well-formed and cover each other.

        Each mask must be a vector of length ``self._repr_length`` whose
        minimum is 0 and maximum is 1, and their element-wise sum must be
        all ones.
        """
        kwargs = self._serialization_utils_kwargs
        expected_shape = (self._repr_length, )
        built_masks = []
        # Build and validate the masks one at a time, observation first.
        for make_mask in (serialization_utils.observation_mask,
                          serialization_utils.action_mask):
            mask = make_mask(**kwargs)
            self.assertEqual(mask.shape, expected_shape)
            self.assertEqual(onp.min(mask), 0)
            self.assertEqual(onp.max(mask), 1)
            built_masks.append(mask)

        observation_part, action_part = built_masks
        # Complementary: every position belongs to exactly one of the masks.
        onp.testing.assert_array_equal(observation_part + action_part,
                                       onp.ones(self._repr_length))