Ejemplo n.º 1
0
def test_get_filtered_observation_space():
    """Check input validation and output shape of the rescale-by-factor filter.

    A filter with factor 0.5 and bilinear interpolation is expected to reject
    a 3D space with 5 channels and a 4D space with ``ValueError``, and to map
    a ``[10, 20, 3]`` space to ``[5, 10, 3]`` without touching the input space.
    """
    input_filter = InputFilter()
    factor_rescaler = ObservationRescaleSizeByFactorFilter(
        0.5, RescaleInterpolationType.BILINEAR)
    input_filter.add_observation_filter('observation', 'rescale',
                                        factor_rescaler)

    rejected_shapes = (
        [10, 20, 5],      # wrong number of channels
        [10, 20, 10, 3],  # wrong number of dimensions
    )
    for rejected_shape in rejected_shapes:
        rejected_space = ObservationSpace(np.array(rejected_shape))
        with pytest.raises(ValueError):
            input_filter.get_filtered_observation_space('observation',
                                                        rejected_space)

    # a valid space: the filtered shape must reflect the 0.5 factor
    original_shape = np.array([10, 20, 3])
    valid_space = ObservationSpace(np.array([10, 20, 3]))
    rescaled_space = input_filter.get_filtered_observation_space(
        'observation', valid_space)
    assert np.all(rescaled_space.shape == np.array([5, 10, 3]))

    # the space that was passed in must keep its original shape
    assert np.all(valid_space.shape == original_shape)
def test_filter():
    """Check that reducing by sub-part name yields observations of the right size.

    With measurements named ``a``, ``b`` and ``c``, keeping ``["a"]`` is
    expected to give a filtered observation of shape ``(1, )`` and discarding
    ``["a"]`` one of shape ``(2, )``. In both cases the observation stored in
    the original response must still have shape ``(3, )``.
    """
    observation_space = VectorObservationSpace(
        3, measurements_names=['a', 'b', 'c'])
    env_response = EnvResponse(next_state={'observation': np.ones([3])},
                               reward=0,
                               game_over=False)

    methods = ObservationReductionBySubPartsNameFilter.ReductionMethod
    scenarios = (
        (methods.Keep, (1, )),
        (methods.Discard, (2, )),
    )
    for reduction_method, expected_shape in scenarios:
        name_filter = ObservationReductionBySubPartsNameFilter(
            ["a"], reduction_method)
        input_filter = InputFilter()
        input_filter.add_observation_filter('observation', 'reduce',
                                            name_filter)

        # the observation space is run through the filter before filtering
        # any response, same as in the original test flow
        input_filter.get_filtered_observation_space('observation',
                                                    observation_space)
        filtered_response = input_filter.filter(env_response)[0]
        observation_before = env_response.next_state['observation']
        observation_after = filtered_response.next_state['observation']

        # make sure the original observation is unchanged
        assert observation_before.shape == (3, )

        # validate the shape of the filtered observation
        assert observation_after.shape == expected_shape
def test_get_filtered_observation_space():
    """Check range validation and output bounds of the to-uint8 filter.

    A filter configured for inputs in ``[0, 200]`` is expected to raise
    ``ValueError`` for a space bounded by ``[0, 100]`` and, for a matching
    space, to return a space with low 0, high 255 and the same shape.
    """
    input_filter = InputFilter()
    to_uint8 = ObservationToUInt8Filter(input_low=0, input_high=200)
    input_filter.add_observation_filter('observation', 'to_uint8', to_uint8)

    # the space bounds do not match the configured input range
    mismatched_space = ObservationSpace(np.array([1, 2, 3]), 0, 100)
    with pytest.raises(ValueError):
        input_filter.get_filtered_observation_space('observation',
                                                    mismatched_space)

    # the space bounds match: verify the resulting space
    matching_space = ObservationSpace(np.array([1, 2, 3]), 0, 200)
    converted_space = input_filter.get_filtered_observation_space(
        'observation', matching_space)
    assert np.all(converted_space.high == 255)
    assert np.all(converted_space.low == 0)
    assert np.all(converted_space.shape == matching_space.shape)
Ejemplo n.º 4
0
def test_get_filtered_observation_space():
    """Check axis validation and output shapes of the observation squeeze filter.

    Covered cases:

    * with ``axis=3``, a ``ValueError`` is expected for a ``[20, 1, 30, 3]``
      space (axis 3 exists but has size 3) and for a ``[20, 1, 30]`` space
      (no axis 3 at all);
    * with ``axis=3``, a ``[1, 2, 3, 1]`` space becomes ``[1, 2, 3]``;
    * with the filter's default axis argument, a ``[1, 2, 3, 1]`` space
      becomes ``[2, 3]``.
    """
    # error on observation space with shape not matching the filter squeeze axis configuration
    squeeze_filter = InputFilter()
    squeeze_filter.add_observation_filter('observation', 'squeeze',
                                          ObservationSqueezeFilter(axis=3))

    observation_space = ObservationSpace(np.array([20, 1, 30, 3]), 0, 100)
    small_observation_space = ObservationSpace(np.array([20, 1, 30]), 0, 100)

    # Each failing call gets its own pytest.raises context. When both calls
    # share a single context, the first exception leaves the block and the
    # second call is never executed, so it is never actually verified.
    with pytest.raises(ValueError):
        squeeze_filter.get_filtered_observation_space('observation',
                                                      observation_space)
    with pytest.raises(ValueError):
        squeeze_filter.get_filtered_observation_space('observation',
                                                      small_observation_space)

    # verify output observation space is correct
    observation_space = ObservationSpace(np.array([1, 2, 3, 1]), 0, 200)
    result = squeeze_filter.get_filtered_observation_space(
        'observation', observation_space)
    assert np.all(result.shape == np.array([1, 2, 3]))

    # a filter built without an explicit axis
    squeeze_filter = InputFilter()
    squeeze_filter.add_observation_filter('observation', 'squeeze',
                                          ObservationSqueezeFilter())

    result = squeeze_filter.get_filtered_observation_space(
        'observation', observation_space)
    assert np.all(result.shape == np.array([2, 3]))
Ejemplo n.º 5
0
def test_get_filtered_observation_space():
    """Check input validation and output shape of the rescale-to-size filter.

    Expected behavior exercised here:

    * setting up a filter whose target image space has 5 channels raises
      ``ValueError``;
    * a filter targeting ``[5, 10, 3]`` raises ``ValueError`` for a 5-channel
      planar-maps space and for a 4D space;
    * a ``[10, 20, 3]`` image space is mapped to ``[5, 10, 3]`` while the
      input space keeps its shape.
    """
    # error on wrong number of channels
    with pytest.raises(ValueError):
        failing_input_filter = InputFilter()
        bad_target_filter = ObservationRescaleToSizeFilter(
            ImageObservationSpace(np.array([5, 10, 5]), high=255),
            RescaleInterpolationType.BILINEAR)
        failing_input_filter.add_observation_filter(
            'observation', 'rescale', bad_target_filter)

    # a filter with a valid 3-channel target, shared by the remaining checks
    input_filter = InputFilter()
    to_size_filter = ObservationRescaleToSizeFilter(
        ImageObservationSpace(np.array([5, 10, 3]), high=255),
        RescaleInterpolationType.BILINEAR)
    input_filter.add_observation_filter('observation', 'rescale',
                                        to_size_filter)

    # mismatch and wrong number of channels
    planar_space = PlanarMapsObservationSpace(np.array([10, 20, 5]),
                                              low=0,
                                              high=255)
    with pytest.raises(ValueError):
        input_filter.get_filtered_observation_space('observation',
                                                    planar_space)

    # error on wrong number of dimensions
    four_dim_space = ObservationSpace(np.array([10, 20, 10, 3]), high=255)
    with pytest.raises(ValueError):
        input_filter.get_filtered_observation_space('observation',
                                                    four_dim_space)

    # make sure the new observation space shape is calculated correctly
    image_space = ImageObservationSpace(np.array([10, 20, 3]), high=255)
    rescaled_space = input_filter.get_filtered_observation_space(
        'observation', image_space)
    assert np.all(rescaled_space.shape == np.array([5, 10, 3]))

    # make sure the original observation space is unchanged
    assert np.all(image_space.shape == np.array([10, 20, 3]))
def test_get_filtered_observation_space():
    """Check the observation space produced by reducing by sub-part names.

    Starting from a 3-element vector space with measurements ``a``, ``b`` and
    ``c``, keeping ``["a"]`` is expected to give a space of shape ``[1]``
    named ``['a']``, and discarding ``["a"]`` a space of shape ``[2]`` named
    ``['b', 'c']``.

    Only observation spaces are involved, so no environment response is
    constructed here.
    """
    reduction_methods = ObservationReductionBySubPartsNameFilter.ReductionMethod
    scenarios = (
        # (reduction method, expected shape, expected measurement names)
        (reduction_methods.Keep, [1], ['a']),
        (reduction_methods.Discard, [2], ['b', 'c']),
    )

    for reduction_method, expected_shape, expected_names in scenarios:
        # a fresh input space per scenario so the scenarios stay independent
        observation_space = VectorObservationSpace(
            3, measurements_names=['a', 'b', 'c'])

        reduction_filter = InputFilter()
        reduction_filter.add_observation_filter(
            'observation', 'reduce',
            ObservationReductionBySubPartsNameFilter(["a"], reduction_method))

        filtered_observation_space = reduction_filter.get_filtered_observation_space(
            'observation', observation_space)
        assert np.all(
            filtered_observation_space.shape == np.array(expected_shape))
        assert filtered_observation_space.measurements_names == expected_names
def test_get_filtered_observation_space():
    """Check the output shape and bounds validation of the crop filter.

    Cropping ``[0, 5, 10]``..``[5, 10, 20]`` out of a ``[5, 10, 20]`` space is
    expected to give ``[5, 5, 10]`` and leave the input space untouched.
    Spaces too small for the crop bounds must raise ``ValueError``. A crop
    high of ``-1`` on an axis is expected to leave that axis at full size.
    """
    lower_bounds = np.array([0, 5, 10])
    upper_bounds = np.array([5, 10, 20])
    input_filter = InputFilter()
    input_filter.add_observation_filter(
        'observation', 'crop',
        ObservationCropFilter(lower_bounds, upper_bounds))

    full_space = ObservationSpace(np.array([5, 10, 20]))
    cropped_space = input_filter.get_filtered_observation_space(
        'observation', full_space)

    # make sure the new observation space shape is calculated correctly
    assert np.all(cropped_space.shape == np.array([5, 5, 10]))

    # make sure the original observation space is unchanged
    assert np.all(full_space.shape == np.array([5, 10, 20]))

    too_small_shapes = (
        [3, 8, 14],  # crop_high is bigger than the observation space
        [3, 3, 10],  # crop_low is bigger than the observation space
    )
    for too_small_shape in too_small_shapes:
        too_small_space = ObservationSpace(np.array(too_small_shape))
        with pytest.raises(ValueError):
            input_filter.get_filtered_observation_space('observation',
                                                        too_small_space)

    # crop with -1 on some axes
    lower_bounds = np.array([0, 0, 0])
    upper_bounds = np.array([5, -1, -1])
    input_filter = InputFilter()
    input_filter.add_observation_filter(
        'observation', 'crop',
        ObservationCropFilter(lower_bounds, upper_bounds))

    full_space = ObservationSpace(np.array([5, 10, 20]))
    cropped_space = input_filter.get_filtered_observation_space(
        'observation', full_space)

    # make sure the new observation space shape is calculated correctly
    assert np.all(cropped_space.shape == np.array([5, 10, 20]))