def test_full_trial():
    """Walk to the default reward arm, collect the reward, and get sent back.

    The agent moves forward three times to the junction, turns with action 2
    (direction 0 -> 1), moves along the hall to the reward cell (reward 1),
    and the following step returns it to the start cell (4, 4) facing
    direction 0.
    """
    hall_obs = [1, 0, 1, 0]
    junction_obs = [0, 1, 0, 0]
    turned_obs = [1, 0, 0, 0]
    reward_cell_obs = [1, 1, 1, 1]

    maze_env = StrictTMazeEnv()
    obs = maze_env.reset()

    # Phase 1: three forward moves up the stem, unrewarded.
    for _stem_step in range(3):
        assert (obs == hall_obs).all()
        assert maze_env.direction == 0
        obs, reward, done, _info = maze_env.step(1)
        assert not done
        assert reward == 0

    # Phase 2: arrived at the junction, still facing direction 0.
    assert (obs == junction_obs).all()
    assert maze_env.direction == 0
    assert reward == 0

    # Phase 3: turn at the junction.
    obs, reward, done, _info = maze_env.step(2)
    assert maze_env.direction == 1
    assert (obs == turned_obs).all()
    assert reward == 0
    assert not done

    # Phase 4: two unrewarded moves along the hall.
    for _hall_step in range(2):
        obs, reward, done, _info = maze_env.step(1)
        assert maze_env.direction == 1
        assert (obs == hall_obs).all()
        assert reward == 0
        assert not done

    # Phase 5: the last forward move reaches the reward cell.
    obs, reward, done, _info = maze_env.step(1)
    assert (obs == reward_cell_obs).all()
    assert reward == 1
    assert maze_env.direction == 1
    assert not done

    # Phase 6: the next step puts the agent back at the start.
    obs, reward, done, _info = maze_env.step(2)
    assert reward == 0
    assert (obs == hall_obs).all()
    assert maze_env.direction == 0
    assert maze_env.row_pos == maze_env.col_pos == 4
    assert not done
def test_cross_turn_penalty():
    """Turning inside the hall is penalised and the agent is sent back to start.

    The agent walks to the junction, turns (direction 0 -> 1), takes one step
    along the hall, then turns again with action 2. That second turn yields
    the -0.4 penalty, and the following step returns the agent to the start
    cell (4, 4) facing direction 0, without reward and without ending the
    episode.
    """
    env = StrictTMazeEnv()
    obs = env.reset()

    # Three unrewarded forward moves up the stem.
    for _ in range(3):
        assert (obs == [1, 0, 1, 0]).all()
        assert env.direction == 0
        obs, reward, done, _ = env.step(1)
        assert not done
        assert reward == 0
    assert (obs == [0, 1, 0, 0]).all()
    assert env.direction == 0
    assert reward == 0

    # Legal turn at the junction.
    obs, reward, done, _ = env.step(2)
    assert env.direction == 1
    assert (obs == [1, 0, 0, 0]).all()
    assert reward == 0
    assert not done

    # One step into the hall.
    obs, reward, done, _ = env.step(1)
    assert (obs == [1, 0, 1, 0]).all()
    assert env.direction == 1
    assert reward == 0
    assert not done

    # Turning inside the hall is penalised.
    obs, reward, done, _ = env.step(2)
    assert env.direction == 2
    assert (obs == [0, 1, 0, 0]).all()
    assert reward == pytest.approx(-0.4)
    assert not done

    # The step after the penalty resets the agent. As in
    # test_step_with_reset, this step carries no reward and does not end the
    # episode (n_trials defaults to 100).
    obs, reward, done, _ = env.step(1)
    assert (obs == [1, 0, 1, 0]).all()
    assert env.row_pos == env.col_pos == 4
    assert env.direction == 0
    assert reward == 0
    assert not done
def test_step_with_reset():
    """Check the state after ``reset``, a penalised step, and the step after it.

    ``reset`` yields a 4-element observation at cell (4, 4). Action 0 from
    there gives reward -0.4. The next action-0 step produces the start
    observation again with reward 0.
    """
    env = StrictTMazeEnv()

    initial_obs = env.reset()
    assert initial_obs.shape == (4, )
    assert env.row_pos == env.col_pos == 4
    assert (initial_obs == [1, 0, 1, 0]).all()

    # Penalised move from the start cell.
    penalised_obs, penalty, penalised_done, _info = env.step(0)
    assert (penalised_obs == [1, 1, 0, 0]).all()
    assert penalty == -0.4
    assert not penalised_done

    # The following step produces the start observation again, unrewarded.
    next_obs, next_reward, next_done, _info = env.step(0)
    assert (next_obs == [1, 0, 1, 0]).all()
    assert next_reward == 0.0
    assert not next_done
def test_optimal():
    """Evaluate the ``make_net``/``activate_net`` policy on four T-mazes.

    The four mazes alternate the initial reward side (1, 0, 1, 0) and each
    runs 100 trials. They are evaluated as one batch of 4 with a budget of
    1600 environment steps, and the resulting fitness must be 98.8.
    """
    # NOTE(review): the test name suggests make_net/activate_net implement a
    # hand-coded optimal policy; both are defined outside this view.
    reward_sides = [1, 0, 1, 0]
    envs = [
        StrictTMazeEnv(init_reward_side=side, n_trials=100)
        for side in reward_sides
    ]

    evaluator = MultiEnvEvaluator(make_net,
                                  activate_net,
                                  envs=envs,
                                  batch_size=4,
                                  max_env_steps=1600)

    fitness = evaluator.eval_genome(None, None)
    # Fitness is built from float rewards (e.g. 0.2 and 1.0), so exact
    # equality is at the mercy of summation order and rounding. Compare with
    # a tolerance instead.
    assert fitness == pytest.approx(98.8)
def test_default_initialization():
    """Default construction: hall length 3, 100 trials, and a fixed 6x9 grid."""
    # 1 appears to mark walls and 0 open cells. The open cells form a stem up
    # column 4 from the start cell (4, 4) and a cross-hall along row 1.
    expected_layout = [
        [1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 0, 0, 0, 0, 0, 0, 0, 1],
        [1, 1, 1, 1, 0, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1],
    ]

    default_env = StrictTMazeEnv()

    assert default_env.hall_len == 3
    assert default_env.n_trials == 100
    assert default_env.maze.shape == (6, 9)
    assert (default_env.maze == expected_layout).all()
def test_reward_flip():
    """The default reward arm switches from the high to the low reward.

    With ``reward_flip_mean=5`` and ``reward_flip_range=3``:

    * trials before ``flip_mean - flip_range`` must pay the high reward 1.0;
    * trials after ``flip_mean + flip_range`` must pay the low reward 0.2;
    * trials in between may pay either.

    The whole episode is replayed several times because the flip point is
    presumably randomised within that window.
    """
    n_trials = 10
    flip_mean = 5
    flip_range = 3
    high_reward = 1.0
    low_reward = 0.2

    env = StrictTMazeEnv(n_trials=n_trials,
                         reward_flip_mean=flip_mean,
                         reward_flip_range=flip_range)
    for _episode in range(5):
        obs = env.reset()
        for trial in range(n_trials):
            # Three unrewarded forward moves up the stem.
            for _ in range(3):
                assert (obs == [1, 0, 1, 0]).all()
                assert env.direction == 0
                obs, reward, done, _ = env.step(1)
                assert not done
                assert reward == 0
            assert (obs == [0, 1, 0, 0]).all()
            assert env.direction == 0
            assert reward == 0

            # Turn at the junction.
            obs, reward, done, _ = env.step(2)
            assert env.direction == 1
            assert (obs == [1, 0, 0, 0]).all()
            assert reward == 0
            assert not done

            # Two unrewarded moves along the hall.
            for _ in range(2):
                obs, reward, done, _ = env.step(1)
                assert env.direction == 1
                assert (obs == [1, 0, 1, 0]).all()
                assert reward == 0
                assert not done

            # Reach the reward cell. The last observation element echoes the
            # reward that was paid.
            obs, reward, done, _ = env.step(1)
            assert (obs[:-1] == [1, 1, 1]).all()
            assert reward == obs[-1]
            assert reward in {low_reward, high_reward}
            if trial < flip_mean - flip_range:
                assert reward == high_reward
            elif trial > flip_mean + flip_range:
                assert reward == low_reward
            assert env.direction == 1
            assert not done

            # Return to the start. Only the final trial's return step may end
            # the episode; checking every trial catches a premature `done`.
            obs, reward, done, _ = env.step(2)
            assert reward == 0
            assert (obs == [1, 0, 1, 0]).all()
            assert env.direction == 0
            assert env.row_pos == env.col_pos == 4
            assert bool(done) == (trial == n_trials - 1)
def test_render():
    """``render`` is not implemented and must say so explicitly."""
    maze_env = StrictTMazeEnv()
    with pytest.raises(NotImplementedError):
        maze_env.render()
def test_step_without_reset():
    """Stepping a freshly built env before ``reset`` raises ``AssertionError``."""
    # NOTE(review): if the env enforces this with a bare ``assert`` statement,
    # the guard vanishes under ``python -O`` and this test would fail there.
    fresh_env = StrictTMazeEnv()
    with pytest.raises(AssertionError):
        fresh_env.step(1)