def validate_input_observation_space(
        self, input_observation_space: "ObservationSpace"):
    """Check that the configured crop window fits inside the observation space.

    ``self.crop_low`` and ``self.crop_high`` are first passed through
    ``self._replace_negative_one_in_crop_size`` together with the observation
    shape. Judging by its name, that helper substitutes ``-1`` entries with
    the matching dimension size (it is not visible here, so confirm).
    ``crop_low`` is an inclusive start index and ``crop_high`` an exclusive
    end index: the last covered index checked below is ``crop_high - 1``.

    :param input_observation_space: the space the crop will be applied to;
        its ``shape`` and ``is_point_in_space_shape`` are used.
    :raises ValueError: if a crop bound exceeds the observation shape, if
        either corner of the window lies outside the space, or if the window
        is empty/inverted (``crop_low >= crop_high`` on some axis).
    """
    shape = input_observation_space.shape
    crop_high = self._replace_negative_one_in_crop_size(self.crop_high, shape)
    crop_low = self._replace_negative_one_in_crop_size(self.crop_low, shape)

    # Coarse bound check; crop_high == shape is allowed since it is exclusive.
    if np.any(crop_high > shape) or np.any(crop_low > shape):
        raise ValueError(
            "The cropping values are outside of the observation space")

    # Both corners of the half-open window [crop_low, crop_high) must be
    # valid indices; the last index covered by crop_high is crop_high - 1.
    if not input_observation_space.is_point_in_space_shape(crop_low) or \
            not input_observation_space.is_point_in_space_shape(crop_high - 1):
        raise ValueError(
            "The cropping indices are outside of the observation space")

    # An empty or inverted window passes both checks above (each corner can
    # individually be a valid index), yet [crop_low, crop_high) then contains
    # no indices along that axis, so reject it explicitly.
    if np.any(crop_low >= crop_high):
        raise ValueError(
            "The cropping low values must be strictly smaller than the "
            "cropping high values")
def test_observation_space():
    """Exercise ObservationSpace value matching and index-in-shape checks.

    The space under test is built from shape ``[1, 10]`` with limits
    ``-10`` and ``10``.
    """
    space = ObservationSpace(np.array([1, 10]), -10, 10)

    # val_matches_space_definition: only a correctly shaped, in-range value
    # is accepted.
    assert space.val_matches_space_definition(np.ones([1, 10]))
    rejected_values = (
        np.ones([2, 10]),          # wrong leading dimension
        np.ones([1, 10]) * 100,    # values far outside the -10..10 limits
        np.ones([1, 1, 10]),       # extra dimension
    )
    for value in rejected_values:
        assert not space.val_matches_space_definition(value)

    # is_point_in_space_shape: every coordinate must lie in [0, dim).
    inside_points = ([0, 9], [0, 0])
    outside_points = ([1, 8], [0, 10], [-1, 6])
    for point in inside_points:
        assert space.is_point_in_space_shape(np.array(point))
    for point in outside_points:
        assert not space.is_point_in_space_shape(np.array(point))