def validate_input_observation_space(self, input_observation_space: ObservationSpace):
    """Check that this filter's crop window fits inside an observation space.

    ``self.crop_high`` and ``self.crop_low`` are first passed through
    ``self._replace_negative_one_in_crop_size`` together with the space's
    shape. Presumably this substitutes ``-1`` placeholders with the matching
    dimension size; the helper is not visible here, so confirm.

    :param input_observation_space: the space whose observations will be cropped.
    :raises ValueError: if any resolved crop bound is larger than the space's
        shape, or if ``crop_low`` / ``crop_high - 1`` is not a point inside the
        space's shape.
    """
    obs_shape = input_observation_space.shape
    # Resolve both bounds against the shape (high first, then low).
    high = self._replace_negative_one_in_crop_size(self.crop_high, obs_shape)
    low = self._replace_negative_one_in_crop_size(self.crop_low, obs_shape)

    exceeds_shape = np.any(high > obs_shape) or np.any(low > obs_shape)
    if exceeds_shape:
        raise ValueError(
            "The cropping values are outside of the observation space")

    # The upper bound is validated as ``high - 1``, i.e. it is treated as an
    # exclusive bound: the last index that must exist is one below it.
    in_space = input_observation_space.is_point_in_space_shape
    if not (in_space(low) and in_space(high - 1)):
        raise ValueError(
            "The cropping indices are outside of the observation space")
Пример #2
0
def test_observation_space():
    """Exercise ObservationSpace value matching and point-in-shape checks."""
    space = ObservationSpace(np.array([1, 10]), -10, 10)

    # val_matches_space_definition: a (1, 10) array of ones is accepted;
    # the variants below must all be rejected.
    assert space.val_matches_space_definition(np.ones([1, 10]))
    rejected_values = (
        np.ones([2, 10]),        # leading dimension 2 instead of 1
        np.ones([1, 10]) * 100,  # values of 100, outside the -10..10 passed above
        np.ones([1, 1, 10]),     # extra dimension
    )
    for value in rejected_values:
        assert not space.val_matches_space_definition(value)

    # is_point_in_space_shape: coordinates must index into the (1, 10) shape.
    inside_points = (
        np.array([0, 9]),   # last valid index along dimension 1
        np.array([0, 0]),   # origin
    )
    outside_points = (
        np.array([1, 8]),   # dimension 0 has size 1, so index 1 is out of range
        np.array([0, 10]),  # dimension 1 has size 10, so index 10 is out of range
        np.array([-1, 6]),  # negative coordinate
    )
    for point in inside_points:
        assert space.is_point_in_space_shape(point)
    for point in outside_points:
        assert not space.is_point_in_space_shape(point)