def test_masks_actions(self):
    """Check that the two masks separate observations from actions.

    Two inputs with identical observations and different actions are
    serialized. Their representations must agree once multiplied by the
    observation mask and must disagree once multiplied by the action mask.
    """
    kwargs = self._serialization_utils_kwargs

    # Observations are the same, actions are different.
    same_observations = onp.array([[0, 1], [0, 1]])
    different_actions = onp.array([[0], [1]])
    reprs = serialization_utils.serialize_observations_and_actions(
        observations=same_observations,
        actions=different_actions,
        **kwargs)

    obs_mask = serialization_utils.observation_mask(**kwargs)
    act_mask = serialization_utils.action_mask(**kwargs)
    first_repr = reprs[0]
    second_repr = reprs[1]

    # Under the observation mask, the two representations are identical.
    onp.testing.assert_array_equal(
        first_repr * obs_mask, second_repr * obs_mask)

    # Under the action mask, the differing actions must remain visible.
    masked_actions_match = onp.array_equal(
        first_repr * act_mask, second_repr * act_mask)
    self.assertFalse(masked_actions_match)
def test_observation_and_action_masks_are_valid_and_complementary(self):
    """Check each mask's shape and value range, and that they sum to one.

    Both masks must have shape ``(self._repr_length,)``, a minimum of 0 and
    a maximum of 1. Added together elementwise they must equal an all-ones
    vector of that length.
    """
    kwargs = self._serialization_utils_kwargs
    mask_builders = (
        serialization_utils.observation_mask,
        serialization_utils.action_mask,
    )

    built_masks = []
    for build_mask in mask_builders:
        mask = build_mask(**kwargs)
        self.assertEqual(mask.shape, (self._repr_length,))
        self.assertEqual(onp.min(mask), 0)
        self.assertEqual(onp.max(mask), 1)
        built_masks.append(mask)

    obs_mask, act_mask = built_masks
    # Complementary: the elementwise sum of the masks is 1 at every position.
    onp.testing.assert_array_equal(
        obs_mask + act_mask, onp.ones(self._repr_length))