Пример #1
0
def test_get_filtered_observation_space(rgb_to_y_filter):
    """Check the RGB-to-Y filter's observation-space handling.

    A space that is not RGB (here the trailing dimension is 4) must be
    rejected with a ValueError, while a [1, 2, 3] RGB space must come out
    with shape [1, 2].
    """
    non_rgb_space = ObservationSpace(np.array([1, 2, 4]), 0, 100)
    with pytest.raises(ValueError):
        rgb_to_y_filter.get_filtered_observation_space(
            'observation', non_rgb_space)

    rgb_space = ObservationSpace(np.array([1, 2, 3]), 0, 100)
    filtered = rgb_to_y_filter.get_filtered_observation_space(
        'observation', rgb_space)
    expected_shape = np.array([1, 2])
    assert np.all(filtered.shape == expected_shape)
 def validate_input_observation_space(
         self, input_observation_space: ObservationSpace):
     """Check that the configured crop window fits inside the input space.

     ``-1`` entries in ``crop_low``/``crop_high`` are first resolved against
     the input shape by ``_replace_negative_one_in_crop_size``.

     :param input_observation_space: the observation space the crop will be
         applied to.
     :raises ValueError: if the crop window exceeds the observation space,
         if its low corner or its last included index is not a valid point
         of the space, or if the window is empty or inverted along any axis.
     """
     crop_high = self._replace_negative_one_in_crop_size(
         self.crop_high, input_observation_space.shape)
     crop_low = self._replace_negative_one_in_crop_size(
         self.crop_low, input_observation_space.shape)
     if np.any(crop_high > input_observation_space.shape) or \
             np.any(crop_low > input_observation_space.shape):
         raise ValueError(
             "The cropping values are outside of the observation space")
     # crop_high is exclusive, so the last index that is actually kept is
     # crop_high - 1
     if not input_observation_space.is_point_in_space_shape(crop_low) or \
             not input_observation_space.is_point_in_space_shape(crop_high - 1):
         raise ValueError(
             "The cropping indices are outside of the observation space")
     # The crop filter's output shape is crop_high - crop_low; a window with
     # crop_high <= crop_low on any axis would produce a zero or negative
     # dimension, so reject it here instead of building a broken space.
     if np.any(crop_high <= crop_low):
         raise ValueError(
             "The cropping high values must be larger than the cropping low values")
 def get_filtered_observation_space(
         self,
         input_observation_space: ObservationSpace) -> ObservationSpace:
     """Add the stacking dimension to the observation space's shape.

     A new axis of size ``self.stack_size`` is inserted at
     ``self.stacking_axis``. Negative axes are interpreted relative to the
     resulting (one dimension larger) shape, the same convention ``np.stack``
     uses: ``-1`` appends the new axis, ``-2`` places it before the last
     existing axis, and so on. (Presumably this matches how the stacked
     observation itself is built - confirm against this filter's ``filter``.)

     The given space is modified in place and returned.

     :param input_observation_space: the observation space to extend.
     :return: the same space object with the stacking axis added to its shape.
     """
     shape = input_observation_space.shape
     axis = self.stacking_axis
     if axis < 0:
         # np.insert counts negative positions from the end of the *input*
         # array, which is one dimension short of the stacked result; shift
         # into the result's frame so that e.g. -2 really ends up
         # second-to-last (and -1 becomes an append at position len(shape)).
         axis += len(shape) + 1
     input_observation_space.shape = np.insert(
         shape,
         obj=axis,
         values=[self.stack_size],
         axis=0)
     return input_observation_space
 def get_filtered_observation_space(
         self,
         input_observation_space: ObservationSpace) -> ObservationSpace:
     """Scale the first two dimensions of the shape by ``rescale_factor``.

     The scaled sizes are truncated to integers and written back into the
     existing shape array in place; any further dimensions are untouched.
     """
     spatial_dims = input_observation_space.shape[:2]
     scaled_dims = spatial_dims * self.rescale_factor
     input_observation_space.shape[:2] = scaled_dims.astype('int')
     return input_observation_space
 def get_filtered_observation_space(
         self,
         input_observation_space: ObservationSpace) -> ObservationSpace:
     """Remove the axis (or axes) ``self.axis`` from the shape via squeeze.

     The squeezed shape is computed on a zero-strided view produced by
     ``np.broadcast_to``. This reuses numpy's own ``squeeze`` rules,
     including the ValueError raised when a selected axis does not have
     size 1. It does so without allocating a buffer the size of the whole
     observation and without drawing from the global numpy random state.

     :param input_observation_space: the observation space to squeeze.
     :return: the same space object with its shape replaced by the squeezed
         shape (a tuple).
     """
     # A broadcast scalar is a read-only view whose memory footprint does not
     # grow with the shape, and creating it has no RNG side effects.
     shape_probe = np.broadcast_to(0, tuple(input_observation_space.shape))
     input_observation_space.shape = shape_probe.squeeze(
         axis=self.axis).shape
     return input_observation_space
Пример #6
0
 def get_filtered_observation_space(
         self,
         input_observation_space: ObservationSpace) -> ObservationSpace:
     """Reduce the first dimension from raw lidar values to sectors.

     The first entry of the shape is divided by
     ``NUMBER_OF_LIDAR_VALUES_IN_EACH_SECTOR``; the shape is updated in place.

     :param input_observation_space: the lidar observation space.
     :return: the same space object with the reduced first dimension.
     """
     # Shape entries are (presumably) integer sizes, so use floor division:
     # true division yields a float that is only truncated implicitly when it
     # happens to be written into an integer array.
     input_observation_space.shape[0] = (
         input_observation_space.shape[0] //
         NUMBER_OF_LIDAR_VALUES_IN_EACH_SECTOR)
     return input_observation_space
Пример #7
0
def test_observation_space():
    """Exercise ``contains`` and ``is_valid_index`` on a [1, 10] space in [-10, 10]."""
    space = ObservationSpace(np.array([1, 10]), -10, 10)

    # contains: shape and value range must both match
    assert space.contains(np.ones([1, 10]))
    rejected_values = (
        np.ones([2, 10]),          # wrong leading dimension
        np.ones([1, 10]) * 100,    # values above the upper bound
        np.ones([1, 1, 10]),       # extra dimension
    )
    for value in rejected_values:
        assert not space.contains(value)

    # is_valid_index: every coordinate must lie inside the [1, 10] shape
    for index in ([0, 9], [0, 0]):
        assert space.is_valid_index(np.array(index))
    for index in ([1, 8], [0, 10], [-1, 6]):
        assert not space.is_valid_index(np.array(index))
def test_get_filtered_observation_space():
    """Check ObservationToUInt8Filter's input validation and output space."""
    uint8_filter = InputFilter()
    to_uint8 = ObservationToUInt8Filter(input_low=0, input_high=200)
    uint8_filter.add_observation_filter('observation', 'to_uint8', to_uint8)

    # the input range [0, 100] does not match the configured [0, 200]
    mismatched_space = ObservationSpace(np.array([1, 2, 3]), 0, 100)
    with pytest.raises(ValueError):
        uint8_filter.get_filtered_observation_space('observation',
                                                    mismatched_space)

    # a matching input range is mapped onto [0, 255], shape unchanged
    matching_space = ObservationSpace(np.array([1, 2, 3]), 0, 200)
    filtered_space = uint8_filter.get_filtered_observation_space(
        'observation', matching_space)
    assert np.all(filtered_space.high == 255)
    assert np.all(filtered_space.low == 0)
    assert np.all(filtered_space.shape == matching_space.shape)
    def get_filtered_observation_space(
            self,
            input_observation_space: ObservationSpace) -> ObservationSpace:
        """Shrink the shape to the size of the crop window.

        ``-1`` entries in ``crop_high``/``crop_low`` are resolved against the
        input shape via ``_replace_negative_one_in_crop_size`` before the
        window size ``crop_high - crop_low`` is stored as the new shape.
        """
        current_shape = input_observation_space.shape
        upper = self._replace_negative_one_in_crop_size(
            self.crop_high, current_shape)
        lower = self._replace_negative_one_in_crop_size(
            self.crop_low, current_shape)
        input_observation_space.shape = upper - lower
        return input_observation_space
Пример #10
0
def test_observation_space():
    """Exercise ``val_matches_space_definition`` and ``is_point_in_space_shape``."""
    space = ObservationSpace(np.array([1, 10]), -10, 10)

    # val_matches_space_definition: shape and value range must both match
    assert space.val_matches_space_definition(np.ones([1, 10]))
    mismatching_values = [
        np.ones([2, 10]),          # wrong leading dimension
        np.ones([1, 10]) * 100,    # values above the upper bound
        np.ones([1, 1, 10]),       # extra dimension
    ]
    for value in mismatching_values:
        assert not space.val_matches_space_definition(value)

    # is_point_in_space_shape: coordinates must be inside the [1, 10] shape
    inside_points = ([0, 9], [0, 0])
    outside_points = ([1, 8], [0, 10], [-1, 6])
    for point in inside_points:
        assert space.is_point_in_space_shape(np.array(point))
    for point in outside_points:
        assert not space.is_point_in_space_shape(np.array(point))
    def get_filtered_observation_space(
            self,
            input_observation_space: ObservationSpace) -> ObservationSpace:
        """Move axis ``axis_origin`` of the shape to position ``axis_target``.

        The shape entry at ``axis_origin`` is removed and re-inserted at
        ``axis_target``. A negative ``axis_target`` counts from the end of
        the resulting shape (``-1`` makes the moved axis the last one).

        For a ``PlanarMapsObservationSpace``, ``channels_axis`` is updated so
        it keeps pointing at the same physical axis:

        * if the channels axis is the moved axis, it becomes ``axis_target``;
        * otherwise it follows the shift caused by the move, keeping its
          sign convention (negative vs. non-negative index).

        The given space is modified in place and returned.
        """
        def move_entry(values):
            # remove the entry at axis_origin and re-insert it at axis_target
            moved = values[self.axis_origin]
            remaining = np.delete(values, self.axis_origin)
            if self.axis_target == -1:
                return np.append(remaining, moved)
            if self.axis_target < -1:
                # np.insert counts negative positions on the post-deletion
                # array, which is one entry shorter, hence the +1 shift
                return np.insert(remaining, self.axis_target + 1, moved)
            return np.insert(remaining, self.axis_target, moved)

        num_dims = len(input_observation_space.shape)
        input_observation_space.shape = move_entry(
            input_observation_space.shape)

        # move the channels axis according to the axis change
        if isinstance(input_observation_space, PlanarMapsObservationSpace):
            channels_axis = input_observation_space.channels_axis
            # compare axes by their non-negative index so that aliases such
            # as -1 and num_dims - 1 are recognized as the same axis
            old_channels = channels_axis % num_dims
            if old_channels == self.axis_origin % num_dims:
                input_observation_space.channels_axis = self.axis_target
            else:
                # moving one axis shifts every axis between origin and target
                # by one position; apply the very same move to the axis
                # indices to find where the channels axis ended up
                new_order = move_entry(np.arange(num_dims))
                new_channels = int(
                    np.flatnonzero(new_order == old_channels)[0])
                if channels_axis < 0:
                    new_channels -= num_dims
                input_observation_space.channels_axis = new_channels

        return input_observation_space
 def validate_input_observation_space(
         self, input_observation_space: ObservationSpace):
     """Check that the input space is compatible with this stacking filter.

     :param input_observation_space: the observation space the filter will
         receive.
     :raises ValueError: if observations are already stored in the stack and
         the most recent one does not match the given space, or if
         ``stacking_axis`` is not smaller than the space's number of
         dimensions.
     """
     # once observations are stored, the given space must still describe the
     # most recently stored one
     if len(
             self.stack
     ) > 0 and not input_observation_space.val_matches_space_definition(
             self.stack[-1]):
         raise ValueError(
             "The given input observation space is different than the observations already stored in "
             "the filters memory")
     if input_observation_space.num_dimensions <= self.stacking_axis:
         raise ValueError(
             "The stacking axis is larger than the number of dimensions in the observation space"
         )
def test_get_filtered_observation_space():
    """Check ObservationRescaleSizeByFactorFilter's observation-space handling."""
    rescale_filter = InputFilter()
    rescale_filter.add_observation_filter(
        'observation', 'rescale', ObservationRescaleSizeByFactorFilter(0.5))

    invalid_shapes = (
        [10, 20, 5],       # wrong number of channels
        [10, 20, 10, 3],   # wrong number of dimensions
    )
    for invalid_shape in invalid_shapes:
        invalid_space = ObservationSpace(np.array(invalid_shape))
        with pytest.raises(ValueError):
            rescale_filter.get_filtered_observation_space(
                'observation', invalid_space)

    # the new observation space shape must be calculated correctly ...
    observation_space = ObservationSpace(np.array([10, 20, 3]))
    filtered_space = rescale_filter.get_filtered_observation_space(
        'observation', observation_space)
    assert np.all(filtered_space.shape == np.array([5, 10, 3]))

    # ... while the original observation space stays unchanged
    assert np.all(observation_space.shape == np.array([10, 20, 3]))
def test_get_filtered_observation_space():
    """Check ObservationCropFilter's output shape and crop-window validation."""
    def make_crop_filter(low, high):
        # an InputFilter holding a single crop filter for 'observation'
        crop_input_filter = InputFilter()
        crop_input_filter.add_observation_filter(
            'observation', 'crop', ObservationCropFilter(low, high))
        return crop_input_filter

    crop_filter = make_crop_filter(np.array([0, 5, 10]), np.array([5, 10, 20]))

    observation_space = ObservationSpace(np.array([5, 10, 20]))
    filtered_space = crop_filter.get_filtered_observation_space(
        'observation', observation_space)
    # the new shape is the size of the crop window ...
    assert np.all(filtered_space.shape == np.array([5, 5, 10]))
    # ... and the original observation space is left unchanged
    assert np.all(observation_space.shape == np.array([5, 10, 20]))

    out_of_bounds_shapes = (
        [3, 8, 14],   # crop_high is bigger than the observation space
        [3, 3, 10],   # crop_low is bigger than the observation space
    )
    for shape in out_of_bounds_shapes:
        too_small_space = ObservationSpace(np.array(shape))
        with pytest.raises(ValueError):
            crop_filter.get_filtered_observation_space('observation',
                                                       too_small_space)

    # -1 in crop_high is replaced with the full size of that axis
    crop_filter = make_crop_filter(np.array([0, 0, 0]), np.array([5, -1, -1]))
    observation_space = ObservationSpace(np.array([5, 10, 20]))
    filtered_space = crop_filter.get_filtered_observation_space(
        'observation', observation_space)
    assert np.all(filtered_space.shape == np.array([5, 10, 20]))
def test_get_filtered_observation_space(stack_filter, env_response):
    """Check the stacking filter's output space and its stack-mismatch check."""
    observation_space = ObservationSpace(np.array([5, 10, 20]))
    stacked_space = stack_filter.get_filtered_observation_space(
        'observation', observation_space)

    # a stacking axis of size 4 ends up last in the new shape ...
    assert np.all(stacked_space.shape == np.array([5, 10, 20, 4]))
    # ... without touching the original observation space
    assert np.all(observation_space.shape == np.array([5, 10, 20]))

    # once the stack holds an observation that does not match this space,
    # asking for the filtered space again must fail
    _ = stack_filter.filter(env_response)[0]
    with pytest.raises(ValueError):
        stack_filter.get_filtered_observation_space('observation',
                                                    observation_space)
Пример #16
0
def test_get_filtered_observation_space():
    """Check ObservationRescaleToSizeFilter's construction and space checks."""
    # the filter rejects a target space with the wrong number of channels
    with pytest.raises(ValueError):
        InputFilter().add_observation_filter(
            'observation', 'rescale',
            ObservationRescaleToSizeFilter(
                ImageObservationSpace(np.array([5, 10, 5]), high=255),
                RescaleInterpolationType.BILINEAR))

    rescale_filter = InputFilter()
    target_space = ImageObservationSpace(np.array([5, 10, 3]), high=255)
    rescale_filter.add_observation_filter(
        'observation', 'rescale',
        ObservationRescaleToSizeFilter(target_space,
                                       RescaleInterpolationType.BILINEAR))

    # planar maps input whose channel count does not match the target
    planar_space = PlanarMapsObservationSpace(np.array([10, 20, 5]),
                                              low=0, high=255)
    with pytest.raises(ValueError):
        rescale_filter.get_filtered_observation_space('observation',
                                                      planar_space)

    # input with the wrong number of dimensions
    four_dim_space = ObservationSpace(np.array([10, 20, 10, 3]), high=255)
    with pytest.raises(ValueError):
        rescale_filter.get_filtered_observation_space('observation',
                                                      four_dim_space)

    # a valid image input gets the target shape ...
    image_space = ImageObservationSpace(np.array([10, 20, 3]), high=255)
    filtered_space = rescale_filter.get_filtered_observation_space(
        'observation', image_space)
    assert np.all(filtered_space.shape == np.array([5, 10, 3]))

    # ... while the original observation space stays unchanged
    assert np.all(image_space.shape == np.array([10, 20, 3]))
Пример #17
0
 def get_filtered_observation_space(
         self,
         input_observation_space: ObservationSpace) -> ObservationSpace:
     """Set the space's value range to the uint8 range [0, 255].

     The shape is left untouched; the space is modified in place and
     returned.
     """
     uint8_low, uint8_high = 0, 255
     input_observation_space.low = uint8_low
     input_observation_space.high = uint8_high
     return input_observation_space
 def get_filtered_observation_space(
         self,
         input_observation_space: ObservationSpace) -> ObservationSpace:
     """Replace the input shape with the filter's configured output shape.

     :param input_observation_space: the observation space to update.
     :return: the same space object, now holding a copy of
         ``self.output_observation_space.shape``.
     """
     # Copy rather than alias. Later filters in a chain, or callers, may
     # update the returned shape in place (e.g. ``shape[:2] = ...``), which
     # would otherwise silently rewrite this filter's configured output space.
     input_observation_space.shape = np.array(
         self.output_observation_space.shape)
     return input_observation_space
Пример #19
0
 def get_filtered_observation_space(
         self,
         input_observation_space: ObservationSpace) -> ObservationSpace:
     """Drop the last dimension from the space's shape and return the space."""
     current_shape = input_observation_space.shape
     input_observation_space.shape = current_shape[:len(current_shape) - 1]
     return input_observation_space