def test_filter():
    """Check BoxDiscretization against a 1-D and a 2-D box action space.

    Verifies that:
    * validating a ``DiscreteActionSpace`` as the output space raises
      ``ValueError``;
    * ``BoxDiscretization(9)`` over ``BoxActionSpace(1, 5, 15)`` yields the
      nine single-element targets 5.0 .. 15.0 (step 1.25) and a discrete
      input space with actions 0..8;
    * ``BoxDiscretization(3)`` over ``BoxActionSpace(2, 5, 15)`` yields the
      nine two-element targets built from {5, 10, 15} and a discrete input
      space with actions 0..8;
    * in both cases, filtering action ``2`` returns the third target and the
      result is contained in the output space.
    """
    discretizer = BoxDiscretization(9)

    # a discrete space is not an acceptable output space for this filter
    with pytest.raises(ValueError):
        discretizer.validate_output_action_space(DiscreteActionSpace(10))

    # ---- 1 dimensional box ----
    box_1d = BoxActionSpace(1, 5, 15)
    discrete_1d = discretizer.get_unfiltered_action_space(box_1d)

    expected_targets_1d = [
        [5.], [6.25], [7.5], [8.75], [10.], [11.25], [12.5], [13.75], [15.],
    ]
    assert discretizer.target_actions == expected_targets_1d
    assert discrete_1d.actions == list(range(9))

    filtered_1d = discretizer.filter(2)
    assert filtered_1d == [7.5]
    assert box_1d.contains(filtered_1d)

    # ---- 2 dimensional box ----
    discretizer = BoxDiscretization(3)
    box_2d = BoxActionSpace(2, 5, 15)
    discrete_2d = discretizer.get_unfiltered_action_space(box_2d)

    expected_targets_2d = [
        [5., 5.], [5., 10.], [5., 15.],
        [10., 5.], [10., 10.], [10., 15.],
        [15., 5.], [15., 10.], [15., 15.],
    ]
    assert discretizer.target_actions == expected_targets_2d
    assert discrete_2d.actions == list(range(9))

    filtered_2d = discretizer.filter(2)
    assert filtered_2d == [5., 15.]
    assert box_2d.contains(filtered_2d)
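# --- Example #2 ---
# (separator residue from the listing this snippet came from; the stray "0" line carried no code)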
def test_filter():
    """Check BoxMasking against a 1-D box action space.

    Verifies that:
    * validating a ``DiscreteActionSpace`` as the output space raises
      ``ValueError``;
    * with ``BoxMasking(10, 20)`` and ``BoxActionSpace(1, 5, 30)`` as the
      output space, filtering ``np.array([2])`` gives ``np.array([12])`` and
      the result is contained in the output space.
    """
    # Named ``box_masking`` rather than ``filter`` to avoid shadowing the builtin.
    box_masking = BoxMasking(10, 20)

    # passing an output space that is wrong
    with pytest.raises(ValueError):
        box_masking.validate_output_action_space(DiscreteActionSpace(10))

    # 1 dimensional box
    output_space = BoxActionSpace(1, 5, 30)
    # The returned input space is not inspected by this test, so it is not
    # bound to a name. The call itself is kept because it ran before filter()
    # in the original test.
    # NOTE(review): presumably it registers the output space on the filter - confirm.
    box_masking.get_unfiltered_action_space(output_space)

    action = np.array([2])
    result = box_masking.filter(action)
    # Explicit array comparison: ``assert a == b`` on arrays depends on the
    # truth value of an element-wise result, which only works for size-1 arrays.
    np.testing.assert_array_equal(result, np.array([12]))
    assert output_space.contains(result)