コード例 #1
0
def test_discretizer_undiscretize_accept_plain_numbers(max_difference_lin):
    """``undiscretize`` accepts a plain Python ``int`` bin index.

    A one-dimensional observation is discretized with 5 linear bins. The
    resulting bin index is converted to a plain ``int`` and undiscretized,
    and the round-trip error must stay within the tolerance returned by the
    ``max_difference_lin`` fixture.
    """
    bins = 5
    box = Box(low=np.array([1.2]), high=np.array([1.5]))
    obs = [1.3333]
    d = Discretizer(box, num_bins=bins, log_bins=False)
    disc = d.discretize(obs)
    # ``int(arr)`` on an array with ndim > 0 is deprecated since NumPy 1.25.
    # ``.item()`` extracts the single element from either a size-1 array or
    # a NumPy scalar without triggering that warning.
    plain_index = int(np.asarray(disc).item())
    undisc_obs = d.undiscretize(plain_index)
    # Make the list -> array conversion explicit instead of relying on
    # ndarray's reflected subtraction.
    error = abs(np.asarray(obs) - undisc_obs)
    assert (error <= max_difference_lin(box, bins)).all()
コード例 #2
0
def test_discretizer_log_outside(box, obs_center, obs_out):
    """Out-of-range observations are clamped to the box edge on their side.

    With log-spaced bins, each component of ``obs_out`` that lies below the
    corresponding ``obs_center`` component must come back as ``box.low``.
    Every other component must come back as ``box.high``.
    """
    discretizer = Discretizer(box, num_bins=5, log_bins=True)
    round_trip = discretizer.undiscretize(discretizer.discretize(obs_out))
    for idx, value in enumerate(round_trip):
        # Pick the edge the value is expected to snap to.
        edge = box.low[idx] if obs_out[idx] < obs_center[idx] else box.high[idx]
        assert value == pytest.approx(edge)
コード例 #3
0
def test_discretizer_log(box, obs_center):
    """Log-binned round trip of a central observation stays inside the box.

    Components whose original value is exactly zero only need to land within
    an absolute distance of 1.0 from zero. All other components must lie
    within ``[box.low, box.high]``.
    """
    discretizer = Discretizer(box, num_bins=5, log_bins=True)
    round_trip = discretizer.undiscretize(discretizer.discretize(obs_center))
    for idx, value in enumerate(round_trip):
        if obs_center[idx] == 0:
            assert value == pytest.approx(0, abs=1.0)
            continue
        assert box.low[idx] <= value <= box.high[idx]
コード例 #4
0
def test_discretizer_discrete():
    """Discrete-space observations are returned unchanged in both directions.

    Checks an array observation and a plain-int observation, then checks that
    ``undiscretize`` also accepts a plain int directly.
    """
    space = Discrete(4 * 7)
    discretizer = Discretizer(space, num_bins=5, log_bins=False)

    # First an array observation, then the same value as a plain number.
    for observation in (np.array([17]), 17):
        encoded = discretizer.discretize(observation)
        decoded = discretizer.undiscretize(encoded)
        expected = np.array(observation)
        assert encoded == expected
        assert decoded == expected

    # Undiscretize should also accept plain numbers:
    assert discretizer.undiscretize(17) == np.array(17)
コード例 #5
0
def test_discretizer_lin(box, obs, max_difference_lin):
    """Linear-binned round trip stays within the per-bin tolerance.

    Every component's absolute round-trip error must not exceed the bound
    returned by the ``max_difference_lin`` fixture for this box and bin count.
    """
    num_bins = 5
    discretizer = Discretizer(box, num_bins=num_bins, log_bins=False)
    round_trip = discretizer.undiscretize(discretizer.discretize(obs))
    tolerance = max_difference_lin(box, num_bins)
    error = abs(obs - round_trip)
    assert (error <= tolerance).all()
コード例 #6
0
def test_discretizer_unsupported_space(box, disc):
    """Constructing a Discretizer over a composite ``Dict`` space raises TypeError."""
    composite_space = Dict(my_box=box, my_disc=disc)
    with pytest.raises(TypeError):
        Discretizer(composite_space, num_bins=5, log_bins=False)
コード例 #7
0
def test_discretizer_wrong_shape(box, invalid_obs):
    """Discretizing an observation of the wrong shape raises a shape ValueError."""
    discretizer = Discretizer(box, num_bins=5, log_bins=False)
    with pytest.raises(ValueError, match="shape"):
        discretizer.discretize(invalid_obs)