@pytest.mark.parametrize('avoid_loops,expected_action', [(False, 0), (True, 1)])
def test_backtracks_because_of_model_loop(avoid_loops, expected_action):
    # 0, action 0 -> 1 (high reward)
    # 1 -> 0 (loop = low value because of the penalty)
    # 0, action 1 -> 2
    # 2 passes: first to expand the root, second to expand the left branch
    # and backpropagate the loop penalty.
    # Should choose 0 or 1 depending on the loop avoidance flag.
    env = testing.TabularEnv(
        init_state=0,
        n_actions=2,
        transitions={
            # Root.
            0: {0: (1, 1, False), 1: (2, 0, True)},
            # Loop in the left branch.
            1: {0: (0, 0, False), 1: (0, 0, False)},
        },
    )
    agent = agents.StochasticMCTSAgent(
        action_space=env.action_space,
        n_passes=2,
        discount=1,
        rate_new_leaves_fn=functools.partial(
            rate_new_leaves_tabular,
            state_values={0: 0, 1: 0, 2: 0},
        ),
        graph_mode=True,
        avoid_loops=avoid_loops,
        loop_penalty=-2,
    )
    observation = env.reset()
    testing.run_without_suspensions(agent.reset(env, observation))
    (action, _) = testing.run_without_suspensions(agent.act(observation))
    assert action == expected_action


def make_one_level_binary_tree(
    left_value, right_value, left_reward=0, right_reward=0
):
    """Makes a TabularEnv and new_leaf_rater_class for a 1-level binary tree."""
    # 0, action 0 -> 1 (left)
    # 0, action 1 -> 2 (right)
    (root_state, left_state, right_state) = (0, 1, 2)
    env = testing.TabularEnv(
        init_state=root_state,
        n_actions=2,
        transitions={
            # state: {action: (state', reward, done)}
            root_state: {
                0: (left_state, left_reward, False),
                1: (right_state, right_reward, False),
            },
            # Dummy terminal states, made so we can expand left and right.
            left_state: {0: (3, 0, True), 1: (4, 0, True)},
            right_state: {0: (5, 0, True), 1: (6, 0, True)},
        },
    )
    new_leaf_rater_class = functools.partial(
        TabularNewLeafRater,
        state_values={
            root_state: 0,
            left_state: left_value,
            right_state: right_value,
            # Dummy terminal states.
            3: 0,
            4: 0,
            5: 0,
            6: 0,
        },
    )
    return (env, new_leaf_rater_class)


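# Illustrative usage of make_one_level_binary_tree (a sketch, not part of the
# original suite; the test name and parameter values are hypothetical). It
# mirrors test_backtracks_because_of_value below: with a much higher value
# estimate on the right branch, the agent is expected to pick action 1.
def test_example_chooses_higher_value_branch():
    (env, new_leaf_rater_class) = make_one_level_binary_tree(
        left_value=-10, right_value=10
    )
    agent = agents.StochasticMCTSAgent(
        n_passes=2,
        new_leaf_rater_class=new_leaf_rater_class,
    )
    observation = env.reset()
    testing.run_without_suspensions(agent.reset(env, observation))
    (action, _) = testing.run_without_suspensions(agent.act(observation))
    # The right branch (action 1) has the much higher value estimate.
    assert action == 1

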
def test_backtracks_because_of_value():
    # 0, action 0 -> 1 (medium value)
    # 0, action 1 -> 2 (high value)
    # 2, action 0 -> 5 (very low value)
    # 2, action 1 -> 6 (very low value)
    # 2 passes, should choose 0.
    env = testing.TabularEnv(
        init_state=0,
        n_actions=2,
        transitions={
            # Root.
            0: {0: (1, 0, False), 1: (2, 0, False)},
            # Left branch, ending here.
            1: {0: (3, 0, True), 1: (4, 0, True)},
            # Right branch, one more level.
            2: {0: (5, 0, False), 1: (6, 0, False)},
            # End of the right branch.
            5: {0: (7, 0, True), 1: (8, 0, True)},
            6: {0: (9, 0, True), 1: (10, 0, True)},
        },
    )
    agent = agents.StochasticMCTSAgent(
        n_passes=2,
        new_leaf_rater_class=functools.partial(
            TabularNewLeafRater,
            state_values={0: 0, 1: 0, 2: 1, 5: -10, 6: -10},
        ),
    )
    observation = env.reset()
    testing.run_without_suspensions(agent.reset(env, observation))
    (action, _) = testing.run_without_suspensions(agent.act(observation))
    assert action == 0


@pytest.mark.parametrize(
    'graph_mode,expected_second_action', [(False, 0), (True, 1)]
)
def test_caches_values_in_graph_mode(graph_mode, expected_second_action):
    # 0, action 0 -> 1 (high value)
    # 1, action 0 -> 2 (very low value)
    # 1, action 1 -> 3 (medium value)
    # 3, action 0 -> 4 (very low value)
    # 3, action 1 -> 5 (very low value)
    # 0, action 1 -> 6 (medium value)
    # 6, action 0 -> 1 (high value)
    # 6, action 1 -> 7 (medium value)
    # 3 passes for the first action and 2 for the second. In graph mode, should
    # choose 1, then 1. Not in graph mode, should choose 1, then 0.
    env = testing.TabularEnv(
        init_state=0,
        n_actions=2,
        transitions={
            # Root.
            0: {0: (1, 0, False), 1: (6, 0, False)},
            # Left branch, long one.
            1: {0: (2, 0, True), 1: (3, 0, False)},
            3: {0: (4, 0, True), 1: (5, 0, True)},
            # Right branch, short, with a connection to the left.
            6: {0: (1, 0, False), 1: (7, 0, True)},
        },
    )
    agent = agents.StochasticMCTSAgent(
        action_space=env.action_space,
        n_passes=3,
        rate_new_leaves_fn=functools.partial(
            rate_new_leaves_tabular,
            state_values={0: 0, 1: 1, 2: -10, 3: 0, 4: -10, 5: -10, 6: 0, 7: 0},
        ),
        graph_mode=graph_mode,
    )
    observation = env.reset()
    testing.run_without_suspensions(agent.reset(env, observation))
    (first_action, _) = testing.run_without_suspensions(agent.act(observation))
    assert first_action == 1

    agent.n_passes = 2
    (observation, _, _, _) = env.step(first_action)
    (second_action, _) = testing.run_without_suspensions(agent.act(observation))
    assert second_action == expected_second_action


def test_stops_on_done():
    # 0 -> 1 (done)
    # 2 passes, env is not stepped from 1.
    env = testing.TabularEnv(
        init_state=0,
        n_actions=1,
        transitions={0: {0: (1, 0, True)}},
    )
    agent = agents.StochasticMCTSAgent(
        n_passes=2,
        new_leaf_rater_class=functools.partial(
            TabularNewLeafRater,
            state_values={0: 0, 1: 0},
        ),
    )
    observation = env.reset()
    testing.run_without_suspensions(agent.reset(env, observation))
    # The new leaf rater errors out when rating nodes not in the value table,
    # so this would fail if the agent tried to step the env past state 1.
    testing.run_without_suspensions(agent.act(observation))


def test_chooses_something_in_dead_end():
    # 0 -> 0
    # 2 passes: first to expand the root, second to merge the child with the
    # root. Should choose 0 and not error out.
    env = testing.TabularEnv(
        init_state=0,
        n_actions=1,
        transitions={0: {0: (0, 0, False)}},
    )
    agent = agents.StochasticMCTSAgent(
        action_space=env.action_space,
        n_passes=2,
        rate_new_leaves_fn=functools.partial(
            rate_new_leaves_tabular,
            state_values={0: 0, 1: 0},
        ),
        graph_mode=True,
        avoid_loops=True,
    )
    observation = env.reset()
    testing.run_without_suspensions(agent.reset(env, observation))
    (action, _) = testing.run_without_suspensions(agent.act(observation))
    assert action == 0


@pytest.mark.parametrize('avoid_loops,expected_action', [(False, 0), (True, 1)])
def test_avoids_real_loops(avoid_loops, expected_action):
    # 0, action 0 -> 0 (high reward)
    # 0, action 1 -> 1 (done)
    # 2 passes: first to expand the root, second to merge the child with the
    # root. Should choose 0 or 1 depending on the loop avoidance flag.
    env = testing.TabularEnv(
        init_state=0,
        n_actions=2,
        transitions={0: {0: (0, 1, False), 1: (1, 0, True)}},
    )
    agent = agents.StochasticMCTSAgent(
        action_space=env.action_space,
        n_passes=2,
        rate_new_leaves_fn=functools.partial(
            rate_new_leaves_tabular,
            state_values={0: 0, 1: 0},
        ),
        graph_mode=True,
        avoid_loops=avoid_loops,
    )
    observation = env.reset()
    testing.run_without_suspensions(agent.reset(env, observation))
    (action, _) = testing.run_without_suspensions(agent.act(observation))
    assert action == expected_action