def test_compute_free_energy(seq, actions, expected):
    """Check _compute_free_energy() on the state reached by replaying actions.

    Args:
        seq: sequence handed to the Lattice2DEnv constructor.
        actions: iterable of actions, each passed to env.step() in order.
        expected: value _compute_free_energy() must return for the final state.
    """
    lattice = Lattice2DEnv(seq)
    # Replay every action; only the resulting state is inspected afterwards.
    for move in actions:
        lattice.step(move)
    energy = lattice._compute_free_energy(lattice.state)
    assert expected == energy
def test_get_adjacent_coords():
    """Check the neighbours _get_adjacent_coords() reports for the origin."""
    env = Lattice2DEnv(generate_sequence(10))
    origin = (0, 0)
    neighbours = env._get_adjacent_coords(origin)
    # Keys are the action indices, values the coordinate each one leads to.
    assert neighbours == {
        0: (-1, 0),
        1: (0, -1),
        2: (0, 1),
        3: (1, 0),
    }
def test_draw_grid(actions, expected_coords):
    """Compare the output of _draw_grid() with the environment's own grid.

    Args:
        actions: value passed to env.step() in a single call.
        expected_coords: index of the grid cell set to 1 in the reference grid.
    """
    lattice = Lattice2DEnv('HH')
    # The grid is drawn from the state as it is *before* the step is taken.
    drawn = lattice._draw_grid(lattice.state)
    lattice.step(actions)
    # `reference` is whatever env.grid returns; it is modified in place below.
    reference = lattice.grid
    reference[expected_coords] = 1
    assert np.array_equal(reference, drawn)
def test_compute_reward_with_trap():
    """Test the reward returned when the episode ends with a trapped agent.

    Replays a fixed action list on a 20-H sequence and checks that every step
    reporting ``done`` carries the expected reward. The test also fails if no
    step ever reports ``done``, so it cannot pass without checking anything.
    """
    seq = 'H' * 20  # sequence of 20 Hs
    env = Lattice2DEnv(seq)
    # Expected reward: 3 minus the trap penalty scaled by the sequence length.
    # NOTE(review): the 3 is presumably the free-energy term for this fold --
    # confirm against _compute_free_energy().
    expected_reward = 3 - (len(seq) * env.trap_penalty)
    # Define sequence of actions that will trap the agent
    actions = [0, 2, 2, 3, 3, 1, 0, 1]
    episode_ended = False
    for action in actions:
        _, reward, done, _ = env.step(action)
        if done:
            episode_ended = True
            assert expected_reward == reward
    # Guard against a vacuous pass: the reward check above only runs on `done`.
    assert episode_ended, "action sequence never ended the episode"
def test_trapped():
    """Check that the trapped flag and done signal appear on the trapping step.

    The first seven moves must leave the episode running and untrapped; the
    eighth move must end it, flag the trap, and return a reward of -5.
    """
    env = Lattice2DEnv("PPPPPPPPPP")  # has 0 reward
    moves_before_trap = (0, 0, 2, 2, 3, 3, 1)
    for move in moves_before_trap:
        _obs, _reward, done, info = env.step(move)
        assert not done
        assert not info["is_trapped"]

    # This final move is the one expected to trap the agent.
    _obs, reward, done, info = env.step(0)
    assert done
    assert info["is_trapped"]
    assert reward == -5  # len(seq) * trap_penalty
def lattice2d_fixed_env():
    """Return a Lattice2DEnv built from the fixed sequence 'HHHH'."""
    return Lattice2DEnv('HHHH')
def lattice2d_env():
    """Return a Lattice2DEnv built from the sequence generate_sequence(10) yields."""
    return Lattice2DEnv(generate_sequence(10))
def test_init_illegal_trap_penalty(penalty):
    """Exception must be raised when an illegal trap penalty is given.

    Args:
        penalty: value passed as ``trap_penalty``; the Lattice2DEnv constructor
            is expected to reject it with ValueError or TypeError.
    """
    # Build the sequence outside the raises block so an error coming from
    # generate_sequence() cannot be mistaken for the constructor's rejection.
    seq = generate_sequence(10)
    with pytest.raises((ValueError, TypeError)):
        Lattice2DEnv(seq, trap_penalty=penalty)
def test_init_invalid_sequence(sequence):
    """Constructing a Lattice2DEnv from an invalid sequence must raise.

    Args:
        sequence: input handed to the constructor; either ValueError or
            AttributeError counts as a correct rejection.
    """
    accepted_errors = (ValueError, AttributeError)
    with pytest.raises(accepted_errors):
        Lattice2DEnv(sequence)
def test_done_seq_length_one():
    """An environment built from a length-1 sequence must already be done."""
    assert Lattice2DEnv("H").done
Spyder Editor This is a temporary script file. """ from gym import spaces from gym_lattice.envs import Lattice2DEnv import numpy as np import random np.random.seed(42) seq = 'HPhP' # Our input sequence action_space = spaces.Discrete(4) # Choose among [0, 1, 2 ,3] N_EPISODES = 100 MAX_EPISODE_STEPS = len(seq) env = Lattice2DEnv(seq) MIN_ALPHA = 0.0001 alphas = np.linspace(1.0, MIN_ALPHA, N_EPISODES) MIN_EPSILON = 0.05 epsilons = np.linspace(0.9, MIN_EPSILON, N_EPISODES) gamma = 0.95 q_table = dict() ################# Functions ################# def choose_action(state, eps): if random.uniform(0, 1) < eps:
def lattice2d_fixed_env():
    """Return a Lattice2DEnv built from the fixed sequence 'HHHH'.

    The environment class is imported inside the function body.
    """
    from gym_lattice.envs import Lattice2DEnv

    return Lattice2DEnv('HHHH')
def lattice2d_env():
    """Return a Lattice2DEnv built from the sequence generate_sequence(10) yields.

    The environment class is imported inside the function body.
    """
    from gym_lattice.envs import Lattice2DEnv

    return Lattice2DEnv(generate_sequence(10))
def test_init_illegal_collision_penalty(penalty):
    """Exception must be raised when an illegal collision penalty is given.

    Args:
        penalty: value passed as ``collision_penalty``; the Lattice2DEnv
            constructor is expected to reject it with ValueError or TypeError.
    """
    from gym_lattice.envs import Lattice2DEnv

    # Only the constructor call belongs inside the raises block: an error from
    # the import or from generate_sequence() must not satisfy the expectation.
    seq = generate_sequence(10)
    with pytest.raises((ValueError, TypeError)):
        Lattice2DEnv(seq, collision_penalty=penalty)
@author: Hengameh """ from gym_lattice.envs import Lattice2DEnv from gym import spaces import numpy as np import random np.random.seed(42) p = [8, 4, 6, 6] # number and length of operators action_space = spaces.Discrete(5) # Choose among [0, 1, 2 , 3, 4] N_EPISODES = 1000 MAX_EPISODE_STEPS = sum(p) env = Lattice2DEnv(p) MIN_ALPHA = 0.0001 alphas = np.linspace(1.0, MIN_ALPHA, N_EPISODES) MIN_EPSILON = 0.0 epsilons = np.linspace(0.9, MIN_EPSILON, N_EPISODES) gamma = 0.95 q_table = dict() ################# Q Learning Functions ################# def choose_action(state, eps): if random.uniform(0, 1) < eps: