def test_backtracks_because_of_model_loop(avoid_loops, expected_action):
    """A loop inside the model branch is penalized, changing the root choice.

    Layout:
      0, action 0 -> 1 (high reward)
      1 -> 0 (loop = low value because of penalty)
      0, action 1 -> 2
    With 2 passes the first expands the root and the second expands the left
    branch and backpropagates the loop penalty. The expected action (0 or 1)
    depends on the loop avoidance flag.
    """
    transitions = {
        # Root.
        0: {0: (1, 1, False), 1: (2, 0, True)},
        # Loop in the left branch.
        1: {0: (0, 0, False), 1: (0, 0, False)},
    }
    env = testing.TabularEnv(init_state=0, n_actions=2, transitions=transitions)
    leaf_rater = functools.partial(
        rate_new_leaves_tabular, state_values={0: 0, 1: 0, 2: 0}
    )
    agent = agents.StochasticMCTSAgent(
        action_space=env.action_space,
        n_passes=2,
        discount=1,
        rate_new_leaves_fn=leaf_rater,
        graph_mode=True,
        avoid_loops=avoid_loops,
        loop_penalty=-2,
    )

    obs = env.reset()
    testing.run_without_suspensions(agent.reset(env, obs))
    chosen_action, _ = testing.run_without_suspensions(agent.act(obs))
    assert chosen_action == expected_action
def test_integration_with_cartpole():
    """Solving CartPole with rollout-based leaf rating yields a non-empty episode."""
    # NOTE(review): another test_integration_with_cartpole is defined in this
    # file; the later module-level binding replaces this one, so pytest only
    # collects the last definition.
    cartpole = envs.CartPole()
    rater_class = functools.partial(
        agents.stochastic_mcts.RolloutNewLeafRater,
        rollout_time_limit=2,
    )
    agent = agents.StochasticMCTSAgent(n_passes=2, new_leaf_rater_class=rater_class)
    episode = testing.run_without_suspensions(agent.solve(cartpole))
    assert episode.transition_batch.observation.shape[0]  # pylint: disable=no-member
def test_backtracks_because_of_value():
    """The agent backtracks when a deeper expansion reveals very low values.

    Layout (matching the transition table below):
      0, action 0 -> 1 (medium value)
      0, action 1 -> 2 (high value)
      2, action 0 -> 5 (very low value)
      2, action 1 -> 6 (very low value)
    With 2 passes, the agent should choose 0.
    """
    env = testing.TabularEnv(
        init_state=0,
        n_actions=2,
        transitions={
            # Root.
            0: {0: (1, 0, False), 1: (2, 0, False)},
            # Left branch, ending here.
            1: {0: (3, 0, True), 1: (4, 0, True)},
            # Right branch, one more level.
            2: {0: (5, 0, False), 1: (6, 0, False)},
            # End of the right branch.
            5: {0: (7, 0, True), 1: (8, 0, True)},
            6: {0: (9, 0, True), 1: (10, 0, True)},
        },
    )
    agent = agents.StochasticMCTSAgent(
        n_passes=2,
        new_leaf_rater_class=functools.partial(
            TabularNewLeafRater,
            state_values={
                0: 0,
                1: 0,
                2: 1,
                5: -10,
                6: -10,
            },
        ),
    )
    observation = env.reset()
    testing.run_without_suspensions(agent.reset(env, observation))
    (action, _) = testing.run_without_suspensions(agent.act(observation))
    assert action == 0
def test_integration_with_cartpole(graph_mode):
    """Solving CartPole with rollout-based leaf rating yields a non-empty episode."""
    # NOTE(review): another test_integration_with_cartpole is defined in this
    # file; only the last module-level binding is collected by pytest.
    cartpole = envs.CartPole()
    rollout_rater = functools.partial(
        agents.stochastic_mcts.rate_new_leaves_with_rollouts,
        rollout_time_limit=2,
    )
    agent = agents.StochasticMCTSAgent(
        action_space=cartpole.action_space,
        n_passes=2,
        rate_new_leaves_fn=rollout_rater,
        graph_mode=graph_mode,
    )
    episode = testing.run_without_suspensions(agent.solve(cartpole))
    assert episode.transition_batch.observation.shape[0]  # pylint: disable=no-member
def test_caches_values_in_graph_mode(graph_mode, expected_second_action):
    """Graph mode changes the second decision in a graph with a shared node.

    Layout (state 1 is reachable both from 0 and from 6):
      0, action 0 -> 1 (high value)
      1, action 0 -> 2 (very low value)
      1, action 1 -> 3 (medium value)
      3, action 0 -> 4 (very low value)
      3, action 1 -> 5 (very low value)
      0, action 1 -> 6 (medium value)
      6, action 0 -> 1 (high value)
      6, action 1 -> 7 (medium value)
    The first action uses 3 passes and the second uses 2. In graph mode the
    agent should choose 1, then 1; outside graph mode, 1, then 0.
    """
    state_values = {0: 0, 1: 1, 2: -10, 3: 0, 4: -10, 5: -10, 6: 0, 7: 0}
    env = testing.TabularEnv(
        init_state=0,
        n_actions=2,
        transitions={
            # Root.
            0: {0: (1, 0, False), 1: (6, 0, False)},
            # Left branch, long one.
            1: {0: (2, 0, True), 1: (3, 0, False)},
            3: {0: (4, 0, True), 1: (5, 0, True)},
            # Right branch, short, with a connection to the left.
            6: {0: (1, 0, False), 1: (7, 0, True)},
        },
    )
    agent = agents.StochasticMCTSAgent(
        action_space=env.action_space,
        n_passes=3,
        rate_new_leaves_fn=functools.partial(
            rate_new_leaves_tabular, state_values=state_values
        ),
        graph_mode=graph_mode,
    )

    obs = env.reset()
    testing.run_without_suspensions(agent.reset(env, obs))
    first_action, _ = testing.run_without_suspensions(agent.act(obs))
    assert first_action == 1

    # The second decision is made with fewer passes.
    agent.n_passes = 2
    obs, _, _, _ = env.step(first_action)
    second_action, _ = testing.run_without_suspensions(agent.act(obs))
    assert second_action == expected_second_action
def test_backtracks_because_of_reward():
    """A very low reward outweighs a high leaf value.

    Layout:
      0, action 0 -> 1 (high value, very low reward)
      0, action 1 -> 2 (medium value)
    With 2 passes, the agent should choose 1.
    """
    # NOTE(review): another test_backtracks_because_of_reward is defined in
    # this file; only the last module-level binding is collected by pytest.
    env, rater_class = make_one_level_binary_tree(
        left_value=1, left_reward=-10, right_value=0, right_reward=0
    )
    agent = agents.StochasticMCTSAgent(n_passes=2, new_leaf_rater_class=rater_class)

    obs = env.reset()
    testing.run_without_suspensions(agent.reset(env, obs))
    chosen_action, _ = testing.run_without_suspensions(agent.act(obs))
    assert chosen_action == 1
def test_act_doesnt_change_env_state(graph_mode, rate_new_leaves_fn):
    """Calling agent.act() leaves the environment's cloned state unchanged."""
    # NOTE(review): another test_act_doesnt_change_env_state is defined in this
    # file; only the last module-level binding is collected by pytest.
    cartpole = envs.CartPole()
    agent = agents.StochasticMCTSAgent(
        action_space=cartpole.action_space,
        n_passes=2,
        rate_new_leaves_fn=rate_new_leaves_fn,
        graph_mode=graph_mode,
    )
    obs = cartpole.reset()
    testing.run_without_suspensions(agent.reset(cartpole, obs))

    snapshot_before = cartpole.clone_state()
    testing.run_with_dummy_network(agent.act(obs))
    snapshot_after = cartpole.clone_state()
    np.testing.assert_equal(snapshot_before, snapshot_after)
def test_act_doesnt_change_env_state(new_leaf_rater_class):
    """Calling agent.act() leaves the environment's cloned state unchanged."""
    cartpole = envs.CartPole()
    agent = agents.StochasticMCTSAgent(
        n_passes=2,
        new_leaf_rater_class=new_leaf_rater_class,
    )
    obs = cartpole.reset()
    testing.run_without_suspensions(agent.reset(cartpole, obs))

    snapshot_before = cartpole.clone_state()
    # act() is driven with dummy network predictions matching the agent's
    # declared network signature.
    signature = agent.network_signature(
        cartpole.observation_space, cartpole.action_space
    )
    testing.run_with_dummy_network_prediction(agent.act(obs), signature)
    snapshot_after = cartpole.clone_state()
    np.testing.assert_equal(snapshot_before, snapshot_after)
def test_backtracks_because_of_reward(graph_mode):
    """A very low reward outweighs a high leaf value.

    Layout:
      0, action 0 -> 1 (high value, very low reward)
      0, action 1 -> 2 (medium value)
    With 2 passes, the agent should choose 1.
    """
    env, leaf_rater = make_one_level_binary_tree(
        left_value=1, left_reward=-10, right_value=0, right_reward=0
    )
    agent = agents.StochasticMCTSAgent(
        action_space=env.action_space,
        n_passes=2,
        rate_new_leaves_fn=leaf_rater,
        graph_mode=graph_mode,
    )

    obs = env.reset()
    testing.run_without_suspensions(agent.reset(env, obs))
    chosen_action, _ = testing.run_without_suspensions(agent.act(obs))
    assert chosen_action == 1
def test_stops_on_done():
    """The search must not step the model past a terminal transition.

    Layout:
      0 -> 1 (done)
    With 2 passes, the env must not be stepped from state 1.
    """
    env = testing.TabularEnv(
        init_state=0,
        n_actions=1,
        transitions={0: {0: (1, 0, True)}},
    )
    agent = agents.StochasticMCTSAgent(
        n_passes=2,
        new_leaf_rater_class=functools.partial(
            TabularNewLeafRater,
            state_values={0: 0, 1: 0},
        ),
    )
    observation = env.reset()
    testing.run_without_suspensions(agent.reset(env, observation))
    # The tabular env defines no transitions out of state 1, and the leaf
    # rater errors out when rating states missing from state_values. Searching
    # past the terminal state would therefore surface as an error here.
    (action, _) = testing.run_without_suspensions(agent.act(observation))
    # Only one action exists, so it is the only valid choice.
    assert action == 0
def test_callbacks_are_called():
    """Every callback hook fires the expected number of times over an episode."""
    callback = mock.MagicMock()

    def make_callback():
        # Always hand back the same mock so its call counts can be inspected.
        return callback

    env, rater_class = make_one_level_binary_tree(left_value=0, right_value=0)
    agent = agents.StochasticMCTSAgent(
        n_passes=2,
        new_leaf_rater_class=rater_class,
        callback_classes=[make_callback],
    )
    testing.run_without_suspensions(agent.solve(env))

    callback.on_episode_begin.assert_called_once()
    expected_counts = (
        (callback.on_pass_begin, 4),
        (callback.on_model_step, 3),
        (callback.on_pass_end, 4),
        (callback.on_real_step, 2),
    )
    for hook, count in expected_counts:
        assert hook.call_count == count
    callback.on_episode_end.assert_called_once()
def test_decision_after_two_passes(
    left_value,
    right_value,
    left_reward,
    right_reward,
    expected_action,
):
    """After 2 passes, the choice reflects the branch qualities.

    Layout:
      0, action 0 -> 1 (left)
      0, action 1 -> 2 (right)
    """
    env, rater_class = make_one_level_binary_tree(
        left_value, right_value, left_reward, right_reward
    )
    agent = agents.StochasticMCTSAgent(n_passes=2, new_leaf_rater_class=rater_class)

    obs = env.reset()
    testing.run_without_suspensions(agent.reset(env, obs))
    chosen_action, _ = testing.run_without_suspensions(agent.act(obs))
    assert chosen_action == expected_action
def test_chooses_something_in_dead_end():
    """With only a self-loop available, the agent still picks an action.

    Layout:
      0 -> 0
    With 2 passes, the first expands the root and the second merges the child
    with the root. The agent should choose 0 and not error out.
    """
    env = testing.TabularEnv(
        init_state=0, n_actions=1, transitions={0: {0: (0, 0, False)}}
    )
    leaf_rater = functools.partial(
        rate_new_leaves_tabular, state_values={0: 0, 1: 0}
    )
    agent = agents.StochasticMCTSAgent(
        action_space=env.action_space,
        n_passes=2,
        rate_new_leaves_fn=leaf_rater,
        graph_mode=True,
        avoid_loops=True,
    )

    obs = env.reset()
    testing.run_without_suspensions(agent.reset(env, obs))
    chosen_action, _ = testing.run_without_suspensions(agent.act(obs))
    assert chosen_action == 0
def test_avoids_real_loops(avoid_loops, expected_action):
    """A rewarding self-loop is taken or avoided depending on the flag.

    Layout:
      0, action 0 -> 0 (high reward)
      0, action 1 -> 1 (done)
    With 2 passes, the first expands the root and the second merges the child
    with the root. The expected action (0 or 1) depends on the loop
    avoidance flag.
    """
    env = testing.TabularEnv(
        init_state=0,
        n_actions=2,
        transitions={0: {0: (0, 1, False), 1: (1, 0, True)}},
    )
    leaf_rater = functools.partial(
        rate_new_leaves_tabular, state_values={0: 0, 1: 0}
    )
    agent = agents.StochasticMCTSAgent(
        action_space=env.action_space,
        n_passes=2,
        rate_new_leaves_fn=leaf_rater,
        graph_mode=True,
        avoid_loops=avoid_loops,
    )

    obs = env.reset()
    testing.run_without_suspensions(agent.reset(env, obs))
    chosen_action, _ = testing.run_without_suspensions(agent.act(obs))
    assert chosen_action == expected_action
def test_decision_after_one_pass(
    left_value,
    right_value,
    left_reward,
    right_reward,
    expected_action,
    graph_mode,
):
    """After a single pass, the choice reflects the branch qualities.

    Layout:
      0, action 0 -> 1 (left)
      0, action 1 -> 2 (right)
    """
    env, leaf_rater = make_one_level_binary_tree(
        left_value, right_value, left_reward, right_reward
    )
    agent = agents.StochasticMCTSAgent(
        action_space=env.action_space,
        n_passes=1,
        rate_new_leaves_fn=leaf_rater,
        graph_mode=graph_mode,
    )

    obs = env.reset()
    testing.run_without_suspensions(agent.reset(env, obs))
    chosen_action, _ = testing.run_without_suspensions(agent.act(obs))
    assert chosen_action == expected_action