def test_iterate_value_q_pi_multi_threaded():
    """
    Test multi-threaded value iteration with asynchronous policy-iteration plotting.

    Runs `iterate_value_q_pi` in a background thread while the main thread
    initializes and updates the policy-iteration plot, then verifies that
    updating the plot from a non-main thread raises ValueError. The training
    thread is aborted via the shared thread manager at the end.
    """

    # shared manager lets the main thread signal the training thread to stop
    thread_manager = RunThreadManager(True)

    def train_thread_target():
        # build a small Gridworld training setup and run a long improvement
        # loop; num_improvements is large because the loop is expected to be
        # aborted externally via thread_manager rather than run to completion
        random_state = RandomState(12345)
        mdp_environment: Gridworld = Gridworld.example_4_1(random_state, None)
        q_S_A = TabularStateActionValueEstimator(mdp_environment, 0.1, None)
        mdp_agent = StochasticMdpAgent('test', random_state, q_S_A.get_initial_policy(), 1)
        iterate_value_q_pi(
            agent=mdp_agent,
            environment=mdp_environment,
            num_improvements=1000000,
            num_episodes_per_improvement=10,
            num_updates_per_improvement=None,
            alpha=0.1,
            mode=Mode.SARSA,
            n_steps=None,
            planning_environment=None,
            make_final_policy_greedy=False,
            q_S_A=q_S_A,
            thread_manager=thread_manager,
            num_improvements_per_plot=10
        )

    # premature update should have no effect (plot not yet initialized)
    assert update_policy_iteration_plot() is None

    # initialize plot from main thread
    plot_policy_iteration(
        iteration_average_reward=[],
        iteration_total_states=[],
        iteration_num_states_improved=[],
        elapsed_seconds_average_rewards={},
        pdf=None
    )

    # run training thread
    run_thread = Thread(target=train_thread_target)
    run_thread.start()
    time.sleep(1)

    # update plot asynchronously from the main thread while training runs
    update_policy_iteration_plot()
    time.sleep(1)

    # should NOT be allowed to update plot from a non-main thread: the update
    # is expected to raise ValueError when called off the main thread
    def bad_update():
        with pytest.raises(ValueError, match='Can only update plot on main thread.'):
            update_policy_iteration_plot()

    bad_thread = Thread(target=bad_update)
    bad_thread.start()
    bad_thread.join()

    # signal the training thread to stop and wait for it to finish
    thread_manager.abort = True
    run_thread.join()
# NOTE(review): this top-level function duplicates the `bad_update` helper
# nested inside test_iterate_value_q_pi_multi_threaded — presumably a stray
# leftover from an edit; confirm whether it has any callers before removing.
def bad_update():
    # updating the policy-iteration plot off the main thread must raise
    with pytest.raises(ValueError, match='Can only update plot on main thread.'):
        update_policy_iteration_plot()