Ejemplo n.º 1
0
def test_iterate_value_q_pi_multi_threaded():
    """Run ``iterate_value_q_pi`` on a background thread while plotting.

    The main thread checks that a premature plot update returns ``None`` and
    initializes the policy-iteration plot. It then starts a training thread
    (SARSA on ``Gridworld.example_4_1``) and updates the plot while training
    runs. Finally it checks that updating the plot from a non-main thread
    raises ``ValueError``, then aborts and joins the training thread.
    """

    thread_manager = RunThreadManager(True)

    def train_thread_target():

        random_state = RandomState(12345)

        mdp_environment: Gridworld = Gridworld.example_4_1(random_state, None)

        q_S_A = TabularStateActionValueEstimator(mdp_environment, 0.1, None)

        mdp_agent = StochasticMdpAgent('test', random_state,
                                       q_S_A.get_initial_policy(), 1)

        # num_improvements is deliberately huge: this test relies on setting
        # thread_manager.abort (in the finally clause below) to end the run.
        iterate_value_q_pi(agent=mdp_agent,
                           environment=mdp_environment,
                           num_improvements=1000000,
                           num_episodes_per_improvement=10,
                           num_updates_per_improvement=None,
                           alpha=0.1,
                           mode=Mode.SARSA,
                           n_steps=None,
                           planning_environment=None,
                           make_final_policy_greedy=False,
                           q_S_A=q_S_A,
                           thread_manager=thread_manager,
                           num_improvements_per_plot=10)

    # premature update should have no effect
    assert update_policy_iteration_plot() is None

    # initialize plot from main thread
    plot_policy_iteration(iteration_average_reward=[],
                          iteration_total_states=[],
                          iteration_num_states_improved=[],
                          elapsed_seconds_average_rewards={},
                          pdf=None)

    # run training thread
    run_thread = Thread(target=train_thread_target)
    run_thread.start()

    # Always abort and join the training thread, even if an assertion or plot
    # call below fails. It is a non-daemon thread with an effectively
    # unbounded run, so it would otherwise keep going after the test.
    try:
        time.sleep(1)

        # update plot asynchronously
        update_policy_iteration_plot()
        time.sleep(1)

        # should NOT be allowed to update plot from non-main thread.
        # Exceptions raised inside a Thread are not propagated to the thread
        # that joins it, so a failing pytest.raises in the worker would go
        # unnoticed. Capture the worker's outcome and re-raise it here.
        bad_update_errors = []

        def bad_update():
            try:
                with pytest.raises(
                        ValueError,
                        match='Can only update plot on main thread.'):
                    update_policy_iteration_plot()
            except BaseException as e:
                # pytest's failure outcome (Failed) derives from
                # BaseException, not Exception, so catch broadly here and
                # hand the error back to the test thread.
                bad_update_errors.append(e)

        bad_thread = Thread(target=bad_update)
        bad_thread.start()
        bad_thread.join()

        if bad_update_errors:
            raise bad_update_errors[0]

    finally:
        thread_manager.abort = True
        run_thread.join()
Ejemplo n.º 2
0
 def bad_update():
     """Expect a plot update made off the main thread to be rejected.

     Fails (via ``pytest.raises``) unless ``update_policy_iteration_plot``
     raises ``ValueError`` whose message matches the regex
     ``'Can only update plot on main thread.'``.
     """
     expect_main_thread_error = pytest.raises(
         ValueError,
         match='Can only update plot on main thread.',
     )
     with expect_main_thread_error:
         update_policy_iteration_plot()