def test_get_filtered_observation_space():
    """Check the output space computed for ObservationRescaleSizeByFactorFilter.

    The filter rescales by a factor of 0.5. An input with 5 channels and an
    input of rank 4 must both be rejected with ValueError. A valid
    (10, 20, 3) input must map to (5, 10, 3), and the caller's input space
    must keep its original shape.
    """
    factor_filter = InputFilter()
    factor_filter.add_observation_filter(
        'observation', 'rescale',
        ObservationRescaleSizeByFactorFilter(0.5, RescaleInterpolationType.BILINEAR))

    # wrong number of channels -> ValueError
    bad_channels_space = ObservationSpace(np.array([10, 20, 5]))
    with pytest.raises(ValueError):
        factor_filter.get_filtered_observation_space('observation', bad_channels_space)

    # wrong number of dimensions -> ValueError
    bad_rank_space = ObservationSpace(np.array([10, 20, 10, 3]))
    with pytest.raises(ValueError):
        factor_filter.get_filtered_observation_space('observation', bad_rank_space)

    # valid input: the first two dimensions are halved, the last one is kept
    input_space = ObservationSpace(np.array([10, 20, 3]))
    output_space = factor_filter.get_filtered_observation_space('observation', input_space)
    assert np.all(output_space.shape == np.array([5, 10, 3]))

    # the input space must not have been modified
    assert np.all(input_space.shape == np.array([10, 20, 3]))
def test_filter():
    """Run ObservationReductionBySubPartsNameFilter over a 3-element observation.

    With measurement names ['a', 'b', 'c'], keeping ["a"] must produce a
    (1,) observation and discarding ["a"] must produce a (2,) observation.
    In both cases the observation stored in the original EnvResponse must
    keep its (3,) shape.
    """
    observation_space = VectorObservationSpace(3, measurements_names=['a', 'b', 'c'])
    env_response = EnvResponse(next_state={'observation': np.ones([3])}, reward=0, game_over=False)

    cases = (
        (ObservationReductionBySubPartsNameFilter.ReductionMethod.Keep, (1, )),
        (ObservationReductionBySubPartsNameFilter.ReductionMethod.Discard, (2, )),
    )
    for reduction_method, expected_shape in cases:
        reduction_filter = InputFilter()
        reduction_filter.add_observation_filter(
            'observation', 'reduce',
            ObservationReductionBySubPartsNameFilter(["a"], reduction_method))
        reduction_filter.get_filtered_observation_space('observation', observation_space)

        filtered_response = reduction_filter.filter(env_response)[0]
        source_obs = env_response.next_state['observation']
        reduced_obs = filtered_response.next_state['observation']

        # the input response's observation keeps its original shape
        assert source_obs.shape == (3, )
        # the filtered observation only has the selected sub-parts
        assert reduced_obs.shape == expected_shape
def test_get_filtered_observation_space():
    """Check ObservationToUInt8Filter's input validation and output space.

    The filter is configured with input_low=0 and input_high=200. An input
    space built with bounds 0/100 does not match that configuration and must
    raise ValueError. A space built with bounds 0/200 must produce an output
    space with low 0, high 255 and an unchanged shape.
    """
    uint8_filter = InputFilter()
    uint8_filter.add_observation_filter(
        'observation', 'to_uint8', ObservationToUInt8Filter(input_low=0, input_high=200))

    # bounds that disagree with the filter configuration are rejected
    mismatched_space = ObservationSpace(np.array([1, 2, 3]), 0, 100)
    with pytest.raises(ValueError):
        uint8_filter.get_filtered_observation_space('observation', mismatched_space)

    # matching bounds: output covers the uint8 range with the same shape
    matching_space = ObservationSpace(np.array([1, 2, 3]), 0, 200)
    output_space = uint8_filter.get_filtered_observation_space('observation', matching_space)
    assert np.all(output_space.high == 255)
    assert np.all(output_space.low == 0)
    assert np.all(output_space.shape == matching_space.shape)
def test_get_filtered_observation_space():
    """Check ObservationSqueezeFilter's validation and output-space shapes.

    With axis=3, two inputs must each raise ValueError: a (20, 1, 30, 3)
    space, whose axis 3 has size 3, and a rank-3 (20, 1, 30) space, which has
    no axis 3. A (1, 2, 3, 1) space must squeeze to (1, 2, 3) with axis=3
    and to (2, 3) when no axis is given.
    """
    squeeze_filter = InputFilter()
    squeeze_filter.add_observation_filter('observation', 'squeeze', ObservationSqueezeFilter(axis=3))
    observation_space = ObservationSpace(np.array([20, 1, 30, 3]), 0, 100)
    small_observation_space = ObservationSpace(np.array([20, 1, 30]), 0, 100)

    # Each invalid space gets its own pytest.raises block. If both calls
    # shared one block, an exception from the first would skip the second
    # call entirely and the rank-3 case would never be exercised.
    with pytest.raises(ValueError):
        # axis 3 exists but its size is not 1
        squeeze_filter.get_filtered_observation_space('observation', observation_space)
    with pytest.raises(ValueError):
        # a rank-3 space has no axis 3
        squeeze_filter.get_filtered_observation_space('observation', small_observation_space)

    # verify output observation space is correct
    observation_space = ObservationSpace(np.array([1, 2, 3, 1]), 0, 200)
    result = squeeze_filter.get_filtered_observation_space('observation', observation_space)
    assert np.all(result.shape == np.array([1, 2, 3]))

    # without an explicit axis, the expected shape has all size-1 axes removed
    squeeze_filter = InputFilter()
    squeeze_filter.add_observation_filter('observation', 'squeeze', ObservationSqueezeFilter())
    result = squeeze_filter.get_filtered_observation_space('observation', observation_space)
    assert np.all(result.shape == np.array([2, 3]))
def test_get_filtered_observation_space():
    """Check ObservationRescaleToSizeFilter's validation and output space.

    The test covers four cases:
    - A 5-channel ImageObservationSpace target is rejected during setup.
    - A 5-channel planar-maps input against a 3-channel target is rejected.
    - A rank-4 input is rejected.
    - The valid (10, 20, 3) -> (5, 10, 3) case succeeds and leaves the input
      space's shape unchanged.
    """
    # wrong number of channels in the target space -> ValueError during setup
    with pytest.raises(ValueError):
        rejected_filters = InputFilter()
        rejected_filters.add_observation_filter(
            'observation', 'rescale',
            ObservationRescaleToSizeFilter(
                ImageObservationSpace(np.array([5, 10, 5]), high=255),
                RescaleInterpolationType.BILINEAR))

    # a valid filter that rescales to a (5, 10, 3) image
    rescale_filter = InputFilter()
    target_space = ImageObservationSpace(np.array([5, 10, 3]), high=255)
    rescale_filter.add_observation_filter(
        'observation', 'rescale',
        ObservationRescaleToSizeFilter(target_space, RescaleInterpolationType.BILINEAR))

    # channel mismatch between input (5) and target (3)
    planar_space = PlanarMapsObservationSpace(np.array([10, 20, 5]), low=0, high=255)
    with pytest.raises(ValueError):
        rescale_filter.get_filtered_observation_space('observation', planar_space)

    # wrong number of dimensions
    rank4_space = ObservationSpace(np.array([10, 20, 10, 3]), high=255)
    with pytest.raises(ValueError):
        rescale_filter.get_filtered_observation_space('observation', rank4_space)

    # valid input maps to the target shape
    image_space = ImageObservationSpace(np.array([10, 20, 3]), high=255)
    output_space = rescale_filter.get_filtered_observation_space('observation', image_space)
    assert np.all(output_space.shape == np.array([5, 10, 3]))

    # the input space must not have been modified
    assert np.all(image_space.shape == np.array([10, 20, 3]))
def test_get_filtered_observation_space():
    """Check the space produced by ObservationReductionBySubPartsNameFilter.

    The input is a 3-element vector space named ['a', 'b', 'c']. Keeping
    ["a"] must yield a 1-element space named ['a']. Discarding ["a"] must
    yield a 2-element space named ['b', 'c'].

    Only observation spaces are exercised here, so no EnvResponse is built;
    the previous, unused env_response objects were dead code.
    """
    # Keep
    observation_space = VectorObservationSpace(3, measurements_names=['a', 'b', 'c'])
    reduction_filter = InputFilter()
    reduction_filter.add_observation_filter(
        'observation', 'reduce',
        ObservationReductionBySubPartsNameFilter(
            ["a"], ObservationReductionBySubPartsNameFilter.ReductionMethod.Keep))
    filtered_observation_space = reduction_filter.get_filtered_observation_space(
        'observation', observation_space)
    assert np.all(filtered_observation_space.shape == np.array([1]))
    assert filtered_observation_space.measurements_names == ['a']

    # Discard: a fresh input space is built so this case does not depend on
    # whether the previous call modified its input.
    observation_space = VectorObservationSpace(3, measurements_names=['a', 'b', 'c'])
    reduction_filter = InputFilter()
    reduction_filter.add_observation_filter(
        'observation', 'reduce',
        ObservationReductionBySubPartsNameFilter(
            ["a"], ObservationReductionBySubPartsNameFilter.ReductionMethod.Discard))
    filtered_observation_space = reduction_filter.get_filtered_observation_space(
        'observation', observation_space)
    assert np.all(filtered_observation_space.shape == np.array([2]))
    assert filtered_observation_space.measurements_names == ['b', 'c']
def test_get_filtered_observation_space():
    """Check ObservationCropFilter's output shape and bounds validation.

    Cropping a (5, 10, 20) space with low (0, 5, 10) and high (5, 10, 20)
    must give (5, 5, 10), and the input space must keep its shape. Spaces
    too small for the crop bounds must raise ValueError. A -1 entry in
    crop_high keeps that axis at its full extent.
    """
    def build_crop_filter(low, high):
        # one InputFilter holding a single crop filter on 'observation'
        input_filter = InputFilter()
        input_filter.add_observation_filter('observation', 'crop', ObservationCropFilter(low, high))
        return input_filter

    crop_filter = build_crop_filter(np.array([0, 5, 10]), np.array([5, 10, 20]))

    # each output dimension equals crop_high - crop_low
    source_space = ObservationSpace(np.array([5, 10, 20]))
    cropped_space = crop_filter.get_filtered_observation_space('observation', source_space)
    assert np.all(cropped_space.shape == np.array([5, 5, 10]))
    # the input space must not have been modified
    assert np.all(source_space.shape == np.array([5, 10, 20]))

    # crop_high exceeds every axis of a (3, 8, 14) space
    too_small_for_high = ObservationSpace(np.array([3, 8, 14]))
    with pytest.raises(ValueError):
        crop_filter.get_filtered_observation_space('observation', too_small_for_high)

    # crop_low exceeds axis 1 of a (3, 3, 10) space
    # (note: crop_high is also out of range for this space)
    too_small_for_low = ObservationSpace(np.array([3, 3, 10]))
    with pytest.raises(ValueError):
        crop_filter.get_filtered_observation_space('observation', too_small_for_low)

    # -1 in crop_high: the last two axes are kept whole
    full_extent_filter = build_crop_filter(np.array([0, 0, 0]), np.array([5, -1, -1]))
    source_space = ObservationSpace(np.array([5, 10, 20]))
    cropped_space = full_extent_filter.get_filtered_observation_space('observation', source_space)
    assert np.all(cropped_space.shape == np.array([5, 10, 20]))