예제 #1
0
def test_observation_when_raw_step_returns_incorrect_no_of_observations():
    """Test that a ServiceError is propagated when raw_step() returns an
    unexpected number of observations.

    A single 'ir' observation is expected, so callbacks returning zero or
    several observations must both be reported as an error.
    """

    def make_failing_raw_step(n: int):
        """Return a raw_step() stand-in that yields ``n`` 'ir' observations."""

        def failing_raw_step(*args, **kwargs):
            """A callback that returns ``n`` observations and done=False."""
            del args  # Unused
            del kwargs  # Unused
            return ["ir"] * n, None, False, {}

        return failing_raw_step

    spaces = [
        ObservationSpace(
            name="ir",
            space=Space(int64_value=Int64Range(min=0)),
        )
    ]

    # Check both "too few" (0) and "too many" (3) observations.
    for returned_count in (0, 3):
        observation = ObservationView(
            make_failing_raw_step(returned_count), spaces
        )
        with pytest.raises(
            ServiceError,
            match=(
                r"^Expected 1 'ir' observation but the service returned "
                rf"{returned_count}$"
            ),
        ):
            observation["ir"]  # pylint: disable=pointless-statement
예제 #2
0
def test_observation_when_raw_step_returns_done():
    """Test that a ServiceError is raised when the raw_step() callback reports
    done=True instead of returning the requested observation.

    If the callback's info dict carries an "error_details" entry, that text is
    appended to the error message after a colon.
    """

    def make_failing_raw_step(error_msg=None):
        """Return a raw_step() stand-in that reports done=True, optionally
        recording ``error_msg`` under info["error_details"]."""

        def failing_raw_step(*args, **kwargs):
            """A callback that returns no observations and done=True."""
            del args  # Unused
            del kwargs  # Unused
            info = {}
            if error_msg:
                info["error_details"] = error_msg
            return [], None, True, info

        return failing_raw_step

    spaces = [
        ObservationSpace(
            name="ir",
            space=Space(int64_value=Int64Range(min=0)),
        )
    ]

    # Without error details the message names only the failed observation.
    observation = ObservationView(make_failing_raw_step(), spaces)
    with pytest.raises(ServiceError,
                       match=r"^Failed to compute observation 'ir'$"):
        observation["ir"]  # pylint: disable=pointless-statement

    # With error details, they are appended to the message.
    observation = ObservationView(make_failing_raw_step("Oh no!"), spaces)
    with pytest.raises(ServiceError,
                       match=r"^Failed to compute observation 'ir': Oh no!$"):
        observation["ir"]  # pylint: disable=pointless-statement
예제 #3
0
def test_register_derived_space():
    """A derived space's value is computed from the base observation by `cb`."""
    base_text = "Hello, world!"
    ir_space = ObservationSpace(
        name="ir",
        string_size_range=ScalarRange(min=ScalarLimit(value=0)),
    )
    source = MockGetObservation(ret=[Observation(string_value=base_text)])
    view = ObservationView(source, [ir_space])

    # "ir_len" wraps the length of the "ir" string in a one-element list.
    view.register_derived_space(
        base_name="ir",
        derived_name="ir_len",
        derived_space=Box(low=0, high=float("inf"), shape=(1,), dtype=int),
        cb=lambda base: [len(base)],
    )

    derived = view["ir_len"]
    assert isinstance(derived, list)
    assert derived == [len(base_text)]
예제 #4
0
def test_invalid_observation_index():
    """Looking up the integer key 100 instead of a space name raises KeyError."""
    observation_spaces = [
        ObservationSpace(
            name="ir",
            string_size_range=ScalarRange(min=ScalarLimit(value=0)),
        ),
    ]
    view = ObservationView(MockGetObservation(), observation_spaces)

    with pytest.raises(KeyError):
        view[100]  # pylint: disable=pointless-statement
예제 #5
0
def test_observed_value_types():
    """Each observation space yields a value of the type its proto declares."""

    def bounded(low, high):
        # Closed [low, high] scalar range.
        return ScalarRange(min=ScalarLimit(value=low),
                           max=ScalarLimit(value=high))

    spaces = [
        ObservationSpace(
            name="ir",
            string_size_range=ScalarRange(min=ScalarLimit(value=0)),
        ),
        ObservationSpace(
            name="features",
            int64_range_list=ScalarRangeList(
                range=[bounded(-100, 100), bounded(-100, 100)]),
        ),
        ObservationSpace(
            name="dfeat",
            double_range_list=ScalarRangeList(range=[bounded(0.5, 2.5)]),
        ),
        ObservationSpace(
            name="binary",
            binary_size_range=bounded(5, 5),
        ),
    ]
    # The mock presumably hands these out in call order, one per lookup
    # below (ir, dfeat, features, binary).
    mock = MockGetObservation(ret=[
        Observation(string_value="Hello, IR"),
        Observation(double_list=DoubleList(value=[1.0, 2.0])),
        Observation(int64_list=Int64List(value=[-5, 15])),
        Observation(binary_value=b"Hello, bytes\0"),
    ])
    view = ObservationView(mock, spaces)

    ir = view["ir"]
    assert isinstance(ir, str)
    assert ir == "Hello, IR"

    dfeat = view["dfeat"]
    np.testing.assert_array_almost_equal(dfeat, [1.0, 2.0])
    assert dfeat.dtype == np.float64

    features = view["features"]
    np.testing.assert_array_equal(features, [-5, 15])
    assert features.dtype == np.int64

    binary = view["binary"]
    assert binary == b"Hello, bytes\0"

    # Each name must map to its position in `spaces`.
    assert mock.called_observation_spaces == [0, 2, 1, 3]
예제 #6
0
def test_invalid_observation_name():
    """An unknown space name raises KeyError whose text is the quoted name."""
    ir_space = ObservationSpace(
        name="ir",
        string_size_range=ScalarRange(min=ScalarLimit(value=0)),
    )
    view = ObservationView(MockGetObservation(), [ir_space])

    with pytest.raises(KeyError) as excinfo:
        view["invalid"]  # pylint: disable=pointless-statement

    message = str(excinfo.value)
    assert message == "'invalid'"
예제 #7
0
def test_observed_value_types():
    """Test that each observation space yields the expected value, both via
    ``observation[name]`` and via the generated bound methods.
    """
    spaces = [
        ObservationSpace(
            name="ir",
            space=Space(string_value=StringSpace(length_range=Int64Range(
                min=0))),
        ),
        ObservationSpace(
            name="features",
            space=Space(int64_box=Int64Box(
                low=Int64Tensor(shape=[2], value=[-100, -100]),
                high=Int64Tensor(shape=[2], value=[100, 100]),
            ), ),
        ),
        ObservationSpace(
            name="dfeat",
            space=Space(double_box=DoubleBox(
                low=DoubleTensor(shape=[1], value=[0.5]),
                high=DoubleTensor(shape=[1], value=[2.5]),
            ), ),
        ),
        # NOTE(review): declared as an int64 value, but the observation below
        # is bytes -- confirm whether a binary/byte-sequence space was meant.
        ObservationSpace(
            name="binary",
            space=Space(int64_value=Int64Range(min=5, max=5)),
        ),
    ]
    # Two rounds of four values, one round per call to check_values() below.
    # MockRawStep presumably returns them in call order.
    mock = MockRawStep(ret=[
        "Hello, IR",
        [1.0, 2.0],
        [-5, 15],
        b"Hello, bytes\0",
        "Hello, IR",
        [1.0, 2.0],
        [-5, 15],
        b"Hello, bytes\0",
    ])
    observation = ObservationView(mock, spaces)

    def check_values(get):
        """Fetch each space through ``get(name)`` and check the results."""
        value = get("ir")
        assert isinstance(value, str)
        assert value == "Hello, IR"

        value = get("dfeat")
        np.testing.assert_array_almost_equal(value, [1.0, 2.0])

        value = get("features")
        np.testing.assert_array_equal(value, [-5, 15])

        value = get("binary")
        assert value == b"Hello, bytes\0"

        # Check that the spaces were requested by name, in lookup order.
        assert mock.called_observation_spaces == [
            "ir", "dfeat", "features", "binary"
        ]
        mock.called_observation_spaces = []

    check_values(lambda name: observation[name])
    # Repeat the above tests using the generated bound methods.
    check_values(lambda name: getattr(observation, name)())
예제 #8
0
def test_empty_space():
    """Constructing a view with an empty space list raises ValueError."""
    no_spaces = []
    with pytest.raises(ValueError) as excinfo:
        ObservationView(MockRawStep(), no_spaces)
    message = str(excinfo.value)
    assert message == "No observation spaces"
예제 #9
0
def test_observed_value_types():
    """Test that each observation space yields the expected value, both via
    ``observation[name]`` and via the generated bound methods.
    """
    spaces = [
        ObservationSpace(
            name="ir",
            string_size_range=ScalarRange(min=ScalarLimit(value=0)),
        ),
        ObservationSpace(
            name="features",
            int64_range_list=ScalarRangeList(range=[
                ScalarRange(min=ScalarLimit(value=-100),
                            max=ScalarLimit(value=100)),
                ScalarRange(min=ScalarLimit(value=-100),
                            max=ScalarLimit(value=100)),
            ]),
        ),
        ObservationSpace(
            name="dfeat",
            double_range_list=ScalarRangeList(range=[
                ScalarRange(min=ScalarLimit(value=0.5),
                            max=ScalarLimit(value=2.5))
            ]),
        ),
        ObservationSpace(
            name="binary",
            binary_size_range=ScalarRange(min=ScalarLimit(value=5),
                                          max=ScalarLimit(value=5)),
        ),
    ]
    # Two rounds of four values, one round per call to check_values() below.
    # MockRawStep presumably returns them in call order.
    mock = MockRawStep(ret=[
        "Hello, IR",
        [1.0, 2.0],
        [-5, 15],
        b"Hello, bytes\0",
        "Hello, IR",
        [1.0, 2.0],
        [-5, 15],
        b"Hello, bytes\0",
    ])
    observation = ObservationView(mock, spaces)

    def check_values(get):
        """Fetch each space through ``get(name)`` and check the results."""
        value = get("ir")
        assert isinstance(value, str)
        assert value == "Hello, IR"

        value = get("dfeat")
        np.testing.assert_array_almost_equal(value, [1.0, 2.0])

        value = get("features")
        np.testing.assert_array_equal(value, [-5, 15])

        value = get("binary")
        assert value == b"Hello, bytes\0"

        # Check that the spaces were requested by name, in lookup order.
        assert mock.called_observation_spaces == [
            "ir", "dfeat", "features", "binary"
        ]
        mock.called_observation_spaces = []

    check_values(lambda name: observation[name])
    # Repeat the above tests using the generated bound methods.
    check_values(lambda name: getattr(observation, name)())