コード例 #1
0
def test_full_trial():
    """Drive one complete trial through the env and check every transition.

    From the start, three action-1 steps are taken, then action 2 (which
    moves ``direction`` from 0 to 1), then three more action-1 steps. The
    last of those yields a reward of 1. A final action-2 step puts the agent
    back at ``(4, 4)`` facing direction 0. ``done`` stays False throughout.
    """
    env = StrictTMazeEnv()
    obs = env.reset()
    assert (obs == [1, 0, 1, 0]).all()
    assert env.direction == 0

    # Each row: (action, expected obs, expected direction, expected reward),
    # all checked immediately after the step is taken.
    script = [
        (1, [1, 0, 1, 0], 0, 0),
        (1, [1, 0, 1, 0], 0, 0),
        (1, [0, 1, 0, 0], 0, 0),
        (2, [1, 0, 0, 0], 1, 0),
        (1, [1, 0, 1, 0], 1, 0),
        (1, [1, 0, 1, 0], 1, 0),
        (1, [1, 1, 1, 1], 1, 1),
        (2, [1, 0, 1, 0], 0, 0),
    ]
    for action, want_obs, want_dir, want_reward in script:
        obs, reward, done, _ = env.step(action)
        assert not done
        assert reward == want_reward
        assert (obs == want_obs).all()
        assert env.direction == want_dir

    # After the final action-2 step the agent is back at the start cell.
    assert env.row_pos == env.col_pos == 4
コード例 #2
0
def test_cross_turn_penalty():
    """Check the -0.4 reward for a second action-2 turn after the junction.

    The agent walks three action-1 steps, takes action 2 (direction 0 -> 1),
    moves one more cell, then takes action 2 again (direction 1 -> 2). That
    second turn yields a reward of -0.4. One further action-1 step returns
    the agent to ``(4, 4)`` facing direction 0.
    """
    env = StrictTMazeEnv()
    obs = env.reset()
    assert (obs == [1, 0, 1, 0]).all()
    assert env.direction == 0

    # Each row: (action, expected obs, expected direction, expected reward),
    # all checked immediately after the step is taken.
    script = [
        (1, [1, 0, 1, 0], 0, 0),
        (1, [1, 0, 1, 0], 0, 0),
        (1, [0, 1, 0, 0], 0, 0),
        (2, [1, 0, 0, 0], 1, 0),
        (1, [1, 0, 1, 0], 1, 0),
        (2, [0, 1, 0, 0], 2, -0.4),
    ]
    for action, want_obs, want_dir, want_reward in script:
        obs, reward, done, _ = env.step(action)
        assert not done
        assert reward == want_reward
        assert (obs == want_obs).all()
        assert env.direction == want_dir

    # Reward and `done` are not asserted for this last step; only the
    # observation and the reset position/direction are.
    obs, reward, done, _ = env.step(1)
    assert (obs == [1, 0, 1, 0]).all()
    assert env.row_pos == env.col_pos == 4
    assert env.direction == 0
コード例 #3
0
def test_step_with_reset():
    """Check the initial state after ``reset()`` and two action-0 steps.

    The first action-0 step is rewarded -0.4 and the second 0.0; neither
    sets ``done``.
    """
    env = StrictTMazeEnv()
    first_obs = env.reset()
    assert first_obs.shape == (4, )
    assert env.row_pos == env.col_pos == 4
    assert (first_obs == [1, 0, 1, 0]).all()

    expectations = (
        ([1, 1, 0, 0], -0.4),
        ([1, 0, 1, 0], 0.0),
    )
    for want_obs, want_reward in expectations:
        obs, reward, done, _ = env.step(0)
        assert (obs == want_obs).all()
        assert reward == want_reward
        assert not done
コード例 #4
0
def test_reward_flip():
    """Run 5 episodes of 10 trials each and check the goal reward flips.

    With ``reward_flip_mean=5`` and ``reward_flip_range=3``, the goal reward
    must be 1.0 in trials 0-1 and 0.2 in trial 9. In between it may be
    either value. ``done`` is only expected to be True after the tenth trial.
    """
    corridor = [1, 0, 1, 0]
    # Each row: (action, expected obs, expected direction) for the six
    # zero-reward steps that precede the goal step in every trial.
    approach = (
        (1, corridor, 0),
        (1, corridor, 0),
        (1, [0, 1, 0, 0], 0),
        (2, [1, 0, 0, 0], 1),
        (1, corridor, 1),
        (1, corridor, 1),
    )
    env = StrictTMazeEnv(n_trials=10, reward_flip_mean=5, reward_flip_range=3)
    for _episode in range(5):
        obs = env.reset()
        assert (obs == corridor).all()
        assert env.direction == 0
        for trial in range(10):
            for action, want_obs, want_dir in approach:
                obs, reward, done, _ = env.step(action)
                assert not done
                assert reward == 0
                assert (obs == want_obs).all()
                assert env.direction == want_dir

            # Goal step: the last observation entry equals the reward.
            obs, reward, done, _ = env.step(1)
            assert (obs[:-1] == [1, 1, 1]).all()
            assert reward == obs[-1]
            assert reward in {0.2, 1.0}
            if trial < 2:
                assert reward == 1.0
            elif trial > 8:
                assert reward == 0.2
            assert env.direction == 1
            assert not done

            # Return step back to the start. `done` is not checked here
            # because it is asserted once, after the final trial below.
            obs, reward, done, _ = env.step(2)
            assert reward == 0
            assert (obs == corridor).all()
            assert env.direction == 0
            assert env.row_pos == env.col_pos == 4
        assert done
コード例 #5
0
def test_step_without_reset():
    """Calling ``step()`` before ``reset()`` must raise ``AssertionError``."""
    fresh_env = StrictTMazeEnv()
    with pytest.raises(AssertionError):
        fresh_env.step(1)