コード例 #1
0
def test_reward_flip():
    """The high-reward arm switches part-way through each episode.

    Several episodes of ``n_trials`` trials each are run. Every trial does
    the following:

    * takes three steps with action 1, then three steps with action 2,
      reaching the end of an arm;
    * checks the reward paid there;
    * takes one more step, which must return the agent to the start cell
      ``(4, 4)``.

    The last observation component must mirror the reward just paid. That
    reward must be either 1.0 or 0.2.

    ``reward_flip_mean`` / ``reward_flip_range`` are assumed to place the
    flip somewhere in ``[mean - range, mean + range]`` (TODO confirm against
    ``TMazeEnv``). Under that assumption, trials before that window must pay
    1.0 and trials after it must pay 0.2.

    The episode must end exactly after the return step of the final trial,
    and not earlier.
    """
    n_episodes = 10
    n_trials = 10
    flip_mean, flip_range = 5, 3
    env = TMazeEnv(
        n_trials=n_trials,
        reward_flip_mean=flip_mean,
        reward_flip_range=flip_range,
    )
    for _episode in range(n_episodes):
        obs = env.reset()
        for i in range(n_trials):
            for _ in range(3):
                assert (obs == [1, 0, 1, 0]).all()
                obs, reward, done, _info = env.step(1)
                assert not done
                assert reward == 0
            for _ in range(3):
                assert (obs == [0, 1, 0, 0]).all()
                assert reward == 0
                obs, reward, done, _info = env.step(2)
                assert not done
            # At the end of the arm, the last obs component echoes the reward.
            assert (obs[:-1] == [0, 1, 1]).all()
            assert reward == obs[-1]
            assert reward in {0.2, 1.0}
            if i < flip_mean - flip_range:
                # Flip cannot have happened yet.
                assert reward == 1.0
            elif i > flip_mean + flip_range:
                # Flip must have happened by now.
                assert reward == 0.2
            obs, reward, done, _info = env.step(2)
            assert reward == 0
            assert (obs == [1, 0, 1, 0]).all()
            assert env.row_pos == env.col_pos == 4
            # Pin termination per trial. Previously only the final `done`
            # was inspected, so an early termination went undetected.
            assert bool(done) == (i == n_trials - 1)
        assert done
コード例 #2
0
def test_step_with_reset():
    """After ``reset``, check the agent's position and the cost of action 0.

    After ``reset`` the agent is at cell (4, 4) and the observation is a
    4-vector. Taking action 0 there leaves the observation unchanged and
    yields reward -0.4 without ending the episode. This is presumably a
    wall-bump penalty; confirm in ``TMazeEnv``.
    """
    env = TMazeEnv()
    expected_obs = [1, 0, 1, 0]

    initial_obs = env.reset()
    assert initial_obs.shape == (4, )
    assert env.row_pos == env.col_pos == 4
    assert (initial_obs == expected_obs).all()

    next_obs, step_reward, is_done, _info = env.step(0)
    assert (next_obs == expected_obs).all()
    assert step_reward == -0.4
    assert not is_done
コード例 #3
0
def test_init_reward_side():
    """With ``init_reward_side=0``, the arm reached via action 0 pays 1.

    The agent takes three steps with action 1, then three with action 0.
    At that point the observation is ``[1, 1, 0, 1]`` and the reward is 1.
    One more step (action 1) must put it back at the start cell (4, 4)
    with zero reward, without ending the episode.
    """
    reset_obs = [1, 0, 1, 0]
    corridor_obs = [0, 1, 0, 0]

    env = TMazeEnv(init_reward_side=0)
    obs = env.reset()

    # First leg: action 1, observation stays at the reset value, no reward.
    for _step in range(3):
        assert (obs == reset_obs).all()
        obs, reward, done, _info = env.step(1)
        assert not done
        assert reward == 0

    # Second leg: action 0. The reward is checked *before* each step, so
    # the reward of the final step is inspected separately below.
    for _step in range(3):
        assert (obs == corridor_obs).all()
        assert reward == 0
        obs, reward, done, _info = env.step(0)
        assert not done

    assert (obs == [1, 1, 0, 1]).all()
    assert reward == 1

    # The following step returns the agent to the start cell.
    obs, reward, done, _info = env.step(1)
    assert reward == 0
    assert (obs == reset_obs).all()
    assert env.row_pos == env.col_pos == 4
    assert not done
コード例 #4
0
def test_full_trial():
    """A default-configured trial along the action-2 arm pays reward 1.

    The agent takes three steps with action 1, then three with action 2.
    It must then observe ``[0, 1, 1, 1]`` and receive reward 1. One more
    action-2 step must return it to the start cell (4, 4) with zero reward,
    and the episode must still be running.
    """
    env = TMazeEnv()
    obs = env.reset()

    steps_taken = 0
    while steps_taken < 3:
        assert (obs == [1, 0, 1, 0]).all()
        obs, reward, done, _info = env.step(1)
        assert not done
        assert reward == 0
        steps_taken += 1

    # Reward is checked before each step here, so the last one is checked
    # after the loop.
    steps_taken = 0
    while steps_taken < 3:
        assert (obs == [0, 1, 0, 0]).all()
        assert reward == 0
        obs, reward, done, _info = env.step(2)
        assert not done
        steps_taken += 1

    assert (obs == [0, 1, 1, 1]).all()
    assert reward == 1

    obs, reward, done, _info = env.step(2)
    assert reward == 0
    assert (obs == [1, 0, 1, 0]).all()
    assert env.row_pos == env.col_pos == 4
    assert not done
コード例 #5
0
def test_low_reward():
    """A default-configured trial along the action-0 arm pays reward 0.2.

    The agent takes three steps with action 1, then three with action 0.
    The final observation's last component must equal the 0.2 reward paid.
    One more action-1 step must return the agent to the start cell (4, 4)
    with zero reward, without ending the episode.
    """
    env = TMazeEnv()
    obs = env.reset()

    def take(action):
        """Step ``env`` with ``action`` and return ``(obs, reward, done)``."""
        next_obs, step_reward, is_done, _info = env.step(action)
        return next_obs, step_reward, is_done

    for _ in range(3):
        assert (obs == [1, 0, 1, 0]).all()
        obs, reward, done = take(1)
        assert not done
        assert reward == 0

    for _ in range(3):
        assert (obs == [0, 1, 0, 0]).all()
        assert reward == 0
        obs, reward, done = take(0)
        assert not done

    assert (obs == [1, 1, 0, 0.2]).all()
    assert reward == 0.2

    obs, reward, done = take(1)
    assert reward == 0
    assert (obs == [1, 0, 1, 0]).all()
    assert env.row_pos == env.col_pos == 4
    assert not done
コード例 #6
0
def test_step_without_reset():
    """Calling ``step`` on a fresh env before ``reset`` must raise.

    The test expects ``AssertionError``. NOTE(review): if ``TMazeEnv.step``
    guards this with a bare ``assert``, the check disappears under
    ``python -O`` and this test would fail; confirm the guard style.
    """
    fresh_env = TMazeEnv()
    with pytest.raises(AssertionError):
        fresh_env.step(1)