def test_full_trial():
    """Walk one successful trial: up the stem, right at the junction, to the reward.

    Checks observation, heading, reward and ``done`` after every step. After
    collecting the reward, one more step (action 2) must put the agent back at
    the start cell ``(4, 4)`` facing direction 0 without ending the episode.
    """
    env = StrictTMazeEnv()
    obs = env.reset()

    # Stem: three forward moves (action 1), no reward, episode still running.
    for _ in range(3):
        assert (obs == [1, 0, 1, 0]).all()
        assert env.direction == 0
        obs, reward, done, _ = env.step(1)
        assert not done
        assert reward == 0

    # At the junction the observation changes; heading is unchanged.
    assert (obs == [0, 1, 0, 0]).all()
    assert env.direction == 0
    assert reward == 0

    # Turn (action 2) at the junction: heading becomes 1, no penalty.
    obs, reward, done, _ = env.step(2)
    assert env.direction == 1
    assert (obs == [1, 0, 0, 0]).all()
    assert reward == 0
    assert not done

    # Two forward moves along the arm; checked after each one.
    for _ in range(2):
        obs, reward, done, _ = env.step(1)
        assert env.direction == 1
        assert (obs == [1, 0, 1, 0]).all()
        assert reward == 0
        assert not done

    # Final forward move reaches the end of the arm and pays reward 1.
    obs, reward, done, _ = env.step(1)
    assert (obs == [1, 1, 1, 1]).all()
    assert reward == 1
    assert env.direction == 1
    assert not done

    # Next step starts a new trial: agent is back at the start, facing 0.
    obs, reward, done, _ = env.step(2)
    assert reward == 0
    assert (obs == [1, 0, 1, 0]).all()
    assert env.direction == 0
    assert env.row_pos == env.col_pos == 4
    assert not done
def test_cross_turn_penalty():
    """Turning away from the junction is penalised, and the next move resets.

    The agent walks to the junction, turns (action 2), advances one cell and
    turns again. That second turn away from the junction must cost -0.4. The
    following forward move must put the agent back at the start cell
    ``(4, 4)`` facing direction 0.
    """
    env = StrictTMazeEnv()
    obs = env.reset()

    # Stem: three forward moves (action 1), no reward, episode still running.
    for _ in range(3):
        assert (obs == [1, 0, 1, 0]).all()
        assert env.direction == 0
        obs, reward, done, _ = env.step(1)
        assert not done
        assert reward == 0

    # Arrived at the junction.
    assert (obs == [0, 1, 0, 0]).all()
    assert env.direction == 0
    assert reward == 0

    # Turning at the junction is free.
    obs, reward, done, _ = env.step(2)
    assert env.direction == 1
    assert (obs == [1, 0, 0, 0]).all()
    assert reward == 0
    assert not done

    # One cell into the arm.
    obs, reward, done, _ = env.step(1)
    assert (obs == [1, 0, 1, 0]).all()
    assert env.direction == 1
    assert reward == 0
    assert not done

    # Turning here (away from the junction) costs -0.4 but does not end the episode.
    obs, reward, done, _ = env.step(2)
    assert env.direction == 2
    assert (obs == [0, 1, 0, 0]).all()
    assert reward == -0.4
    assert not done

    # The following move sends the agent back to the start, facing 0.
    obs, reward, done, _ = env.step(1)
    assert (obs == [1, 0, 1, 0]).all()
    assert env.row_pos == env.col_pos == 4
    assert env.direction == 0
def test_step_with_reset():
    """Action 0 at the start is penalised, and the following step resets.

    After ``reset()`` the observation is a length-4 array and the agent is at
    ``(4, 4)``. The first action 0 costs -0.4. A second action 0 yields the
    start observation again with zero reward; the episode never ends.
    """
    env = StrictTMazeEnv()
    obs = env.reset()
    assert obs.shape == (4,)
    assert env.row_pos == env.col_pos == 4
    assert (obs == [1, 0, 1, 0]).all()

    # Penalised action at the start cell.
    obs, reward, done, _ = env.step(0)
    assert (obs == [1, 1, 0, 0]).all()
    assert reward == -0.4
    assert not done

    # Next step shows the start observation again, unpenalised.
    obs, reward, done, _ = env.step(0)
    assert (obs == [1, 0, 1, 0]).all()
    assert reward == 0.0
    assert not done
def test_reward_flip():
    """Run several multi-trial episodes and check the arm reward flips over time.

    With ``reward_flip_mean=5`` and ``reward_flip_range=3``:

    - the first two trials must pay 1.0;
    - the last trial must pay 0.2;
    - trials in between may pay either value.

    The reward is also mirrored in the last observation element. Every
    non-final trial must leave the episode running; only the last trial's
    final step may report ``done``.
    """
    n_trials = 10
    env = StrictTMazeEnv(
        n_trials=n_trials, reward_flip_mean=5, reward_flip_range=3
    )
    for _episode in range(5):
        obs = env.reset()
        for i in range(n_trials):
            # Stem: three forward moves, no reward, episode still running.
            for _ in range(3):
                assert (obs == [1, 0, 1, 0]).all()
                assert env.direction == 0
                obs, reward, done, _ = env.step(1)
                assert not done
                assert reward == 0

            # Arrived at the junction.
            assert (obs == [0, 1, 0, 0]).all()
            assert env.direction == 0
            assert reward == 0

            # Turn into the arm (free at the junction).
            obs, reward, done, _ = env.step(2)
            assert env.direction == 1
            assert (obs == [1, 0, 0, 0]).all()
            assert reward == 0
            assert not done

            # Two forward moves along the arm.
            for _ in range(2):
                obs, reward, done, _ = env.step(1)
                assert env.direction == 1
                assert (obs == [1, 0, 1, 0]).all()
                assert reward == 0
                assert not done

            # End of the arm: the reward depends on whether the flip has happened.
            obs, reward, done, _ = env.step(1)
            assert (obs[:-1] == [1, 1, 1]).all()
            # The last observation element carries the reward just received.
            assert reward == obs[-1]
            assert reward in {0.2, 1.0}
            if i < 2:
                assert reward == 1.0  # too early for the flip
            elif i > 8:
                assert reward == 0.2  # flip must have happened by now
            assert env.direction == 1
            assert not done

            # Trial ends: agent is back at the start, facing 0.
            obs, reward, done, _ = env.step(2)
            assert reward == 0
            assert (obs == [1, 0, 1, 0]).all()
            assert env.direction == 0
            assert env.row_pos == env.col_pos == 4
            # Catch an episode that terminates before all trials have run.
            if i < n_trials - 1:
                assert not done

        # The final trial's last step ends the episode.
        assert done