예제 #1
0
def test_step_with_reset():
    """Reset a default maze and take action 0 once from the start cell.

    Verifies the shape and content of the initial observation, the
    starting coordinates (4, 4), and that the single step leaves the
    observation unchanged, yields a reward of -0.4 and does not finish
    the episode.
    """
    start_obs = [1, 0, 1, 0]
    maze_env = TMazeEnv()

    first_obs = maze_env.reset()
    assert first_obs.shape == (4,)
    assert maze_env.row_pos == maze_env.col_pos == 4
    assert (first_obs == start_obs).all()

    # NOTE(review): -0.4 is presumably a penalty for an invalid move from
    # the start cell -- confirm against TMazeEnv.
    next_obs, reward, done, _info = maze_env.step(0)
    assert (next_obs == start_obs).all()
    assert reward == -0.4
    assert not done
예제 #2
0
def test_reward_flip():
    """Run 10 episodes of 10 trials and check the reward at the end of each.

    The environment is built with ``n_trials=10``, ``reward_flip_mean=5``
    and ``reward_flip_range=3``. Every trial takes action 1 three times
    and then action 2 three times. The reward received on arrival must
    equal the last observation entry and be either 0.2 or 1.0; it must be
    1.0 for the first two trials and 0.2 for the last one. One further
    step brings the agent back to the start cell, and after the final
    trial that step must report the episode as done.
    """
    n_episodes = 10
    n_trials = 10
    steps_per_leg = 3
    start_obs = [1, 0, 1, 0]
    second_leg_obs = [0, 1, 0, 0]
    possible_rewards = {0.2, 1.0}

    env = TMazeEnv(n_trials=10, reward_flip_mean=5, reward_flip_range=3)
    for _episode in range(n_episodes):
        obs = env.reset()
        for trial in range(n_trials):
            # First leg: the start observation is seen before each step.
            for _step in range(steps_per_leg):
                assert (obs == start_obs).all()
                obs, reward, done, _info = env.step(1)
                assert not done
                assert reward == 0

            # Second leg: the reward checked here is the one returned by
            # the previous step, so no reward is given before arrival.
            for _step in range(steps_per_leg):
                assert (obs == second_leg_obs).all()
                assert reward == 0
                obs, reward, done, _info = env.step(2)
                assert not done

            # On arrival the last observation entry carries the reward.
            assert (obs[:-1] == [0, 1, 1]).all()
            assert reward == obs[-1]
            assert reward in possible_rewards
            # Only the earliest and latest trials have a pinned reward;
            # presumably the flip lands somewhere in between -- confirm.
            if trial < 2:
                assert reward == 1.0
            elif trial > 8:
                assert reward == 0.2

            # One more step puts the agent back on the start cell.
            obs, reward, done, _info = env.step(2)
            assert reward == 0
            assert (obs == start_obs).all()
            assert env.row_pos == env.col_pos == 4
        # `done` here is the flag returned by the last step of the last trial.
        assert done
예제 #3
0
def test_default_initialization():
    """Check the defaults of a freshly constructed ``TMazeEnv``.

    Expects ``hall_len == 3``, ``n_trials == 100`` and a 6x9 maze whose
    open cells (0) form a T: one horizontal row of seven cells with a
    three-cell vertical column below its centre, surrounded by 1s.
    """
    expected_maze = [
        [1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 0, 0, 0, 0, 0, 0, 0, 1],
        [1, 1, 1, 1, 0, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1],
    ]

    maze_env = TMazeEnv()
    assert maze_env.hall_len == 3
    assert maze_env.n_trials == 100
    assert maze_env.maze.shape == (6, 9)
    # Printed before the comparison so the actual layout is visible in
    # the captured output if the assertion below fails.
    print(maze_env.maze)
    assert (maze_env.maze == expected_maze).all()
예제 #4
0
def test_init_reward_side():
    """Walk one trial in an environment built with ``init_reward_side=0``.

    Takes action 1 three times and then action 0 three times. Arrival
    must produce observation ``[1, 1, 0, 1]`` and a reward of 1. One more
    step must put the agent back on the start cell (4, 4) with zero
    reward and without ending the episode.
    """
    steps_per_leg = 3
    start_obs = [1, 0, 1, 0]
    second_leg_obs = [0, 1, 0, 0]

    env = TMazeEnv(init_reward_side=0)
    obs = env.reset()

    # First leg: the start observation is seen before each step.
    for _step in range(steps_per_leg):
        assert (obs == start_obs).all()
        obs, reward, done, _info = env.step(1)
        assert not done
        assert reward == 0

    # Second leg: the reward checked is the one from the previous step.
    for _step in range(steps_per_leg):
        assert (obs == second_leg_obs).all()
        assert reward == 0
        obs, reward, done, _info = env.step(0)
        assert not done

    assert (obs == [1, 1, 0, 1]).all()
    assert reward == 1

    # The step after arrival returns the agent to the start cell.
    obs, reward, done, _info = env.step(1)
    assert reward == 0
    assert (obs == start_obs).all()
    assert env.row_pos == env.col_pos == 4
    assert not done
예제 #5
0
def test_full_trial():
    """Walk one complete trial in a default environment via action 2.

    Takes action 1 three times and then action 2 three times. Arrival
    must produce observation ``[0, 1, 1, 1]`` and a reward of 1. One more
    step must put the agent back on the start cell (4, 4) with zero
    reward and without ending the episode.
    """
    steps_per_leg = 3
    start_obs = [1, 0, 1, 0]
    second_leg_obs = [0, 1, 0, 0]

    env = TMazeEnv()
    obs = env.reset()

    # First leg: the start observation is seen before each step.
    for _step in range(steps_per_leg):
        assert (obs == start_obs).all()
        obs, reward, done, _info = env.step(1)
        assert not done
        assert reward == 0

    # Second leg: the reward checked is the one from the previous step.
    for _step in range(steps_per_leg):
        assert (obs == second_leg_obs).all()
        assert reward == 0
        obs, reward, done, _info = env.step(2)
        assert not done

    assert (obs == [0, 1, 1, 1]).all()
    assert reward == 1

    # The step after arrival returns the agent to the start cell.
    obs, reward, done, _info = env.step(2)
    assert reward == 0
    assert (obs == start_obs).all()
    assert env.row_pos == env.col_pos == 4
    assert not done
예제 #6
0
def test_low_reward():
    """Walk one trial in a default environment via action 0.

    Takes action 1 three times and then action 0 three times. Arrival
    must produce observation ``[1, 1, 0, 0.2]`` and a reward of 0.2. One
    more step must put the agent back on the start cell (4, 4) with zero
    reward and without ending the episode.
    """
    steps_per_leg = 3
    start_obs = [1, 0, 1, 0]
    second_leg_obs = [0, 1, 0, 0]

    env = TMazeEnv()
    obs = env.reset()

    # First leg: the start observation is seen before each step.
    for _step in range(steps_per_leg):
        assert (obs == start_obs).all()
        obs, reward, done, _info = env.step(1)
        assert not done
        assert reward == 0

    # Second leg: the reward checked is the one from the previous step.
    for _step in range(steps_per_leg):
        assert (obs == second_leg_obs).all()
        assert reward == 0
        obs, reward, done, _info = env.step(0)
        assert not done

    assert (obs == [1, 1, 0, 0.2]).all()
    assert reward == 0.2

    # The step after arrival returns the agent to the start cell.
    obs, reward, done, _info = env.step(1)
    assert reward == 0
    assert (obs == start_obs).all()
    assert env.row_pos == env.col_pos == 4
    assert not done
예제 #7
0
def test_render():
    """Calling ``render`` on a default environment raises NotImplementedError."""
    maze_env = TMazeEnv()
    with pytest.raises(NotImplementedError):
        maze_env.render()
예제 #8
0
def test_step_without_reset():
    """Stepping before ``reset`` has been called raises AssertionError."""
    fresh_env = TMazeEnv()
    with pytest.raises(AssertionError):
        fresh_env.step(1)