def test_iterate_value_q_pi():
    """
    Test Monte Carlo iteration of the state-action value function and policy on the example 4.1 gridworld, comparing
    the resulting policy and estimator against a pickled fixture.
    """

    random_state = RandomState(12345)

    mdp_environment: Gridworld = Gridworld.example_4_1(random_state, None)

    q_S_A = TabularStateActionValueEstimator(mdp_environment, 0.1, None)

    mdp_agent = StochasticMdpAgent(
        'test',
        random_state,
        q_S_A.get_initial_policy(),
        1
    )

    iterate_value_q_pi(
        agent=mdp_agent,
        environment=mdp_environment,
        num_improvements=3000,
        num_episodes_per_improvement=1,
        update_upon_every_visit=False,
        planning_environment=None,
        make_final_policy_greedy=False,
        q_S_A=q_S_A
    )

    # uncomment the following two lines and run the test to update the fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_monte_carlo_iteration_of_value_q_pi.pickle', 'wb') as file:
    #     pickle.dump((mdp_agent.pi, q_S_A), file)

    with open(f'{os.path.dirname(__file__)}/fixtures/test_monte_carlo_iteration_of_value_q_pi.pickle', 'rb') as file:
        pi_fixture, q_S_A_fixture = pickle.load(file)

    assert tabular_pi_legacy_eq(mdp_agent.pi, pi_fixture) and tabular_estimator_legacy_eq(q_S_A, q_S_A_fixture)
def test_learn():
    """
    Test temporal-difference (SARSA) learning on the Gym CartPole-v1 environment, comparing the resulting policy and
    estimator against a pickled fixture.
    """

    random_state = RandomState(12345)

    gym = Gym(
        random_state=random_state,
        T=None,
        gym_id='CartPole-v1'
    )

    q_S_A = TabularStateActionValueEstimator(gym, 0.05, 0.001)

    mdp_agent = StochasticMdpAgent(
        'agent',
        random_state,
        q_S_A.get_initial_policy(),
        1
    )

    iterate_value_q_pi(
        agent=mdp_agent,
        environment=gym,
        num_improvements=10,
        num_episodes_per_improvement=100,
        num_updates_per_improvement=None,
        alpha=0.1,
        mode=Mode.SARSA,
        n_steps=1,
        planning_environment=None,
        make_final_policy_greedy=False,
        q_S_A=q_S_A
    )

    # uncomment the following two lines and run the test to update the fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_gym.pickle', 'wb') as file:
    #     pickle.dump((mdp_agent.pi, q_S_A), file)

    with open(f'{os.path.dirname(__file__)}/fixtures/test_gym.pickle', 'rb') as file:
        fixture_pi, fixture_q_S_A = pickle.load(file)

    assert tabular_pi_legacy_eq(mdp_agent.pi, fixture_pi) and tabular_estimator_legacy_eq(q_S_A, fixture_q_S_A)
def test_sarsa_iterate_value_q_pi_with_trajectory_planning():
    """
    Test SARSA iteration of the state-action value function and policy on the example 4.1 gridworld with
    trajectory-sampling planning, comparing the resulting policy and estimator against a pickled fixture.
    """

    random_state = RandomState(12345)

    mdp_environment: Gridworld = Gridworld.example_4_1(random_state, None)

    q_S_A = TabularStateActionValueEstimator(mdp_environment, 0.05, None)

    mdp_agent = ActionValueMdpAgent('test', random_state, 1, q_S_A)

    planning_environment = TrajectorySamplingMdpPlanningEnvironment(
        'test planning',
        random_state,
        StochasticEnvironmentModel(),
        10,
        None
    )

    iterate_value_q_pi(
        agent=mdp_agent,
        environment=mdp_environment,
        num_improvements=100,
        num_episodes_per_improvement=1,
        num_updates_per_improvement=None,
        alpha=0.1,
        mode=Mode.SARSA,
        n_steps=1,
        planning_environment=planning_environment,
        make_final_policy_greedy=True
    )

    # uncomment the following two lines and run the test to update the fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_td_iteration_of_value_q_pi_planning.pickle', 'wb') as file:
    #     pickle.dump((mdp_agent.pi, q_S_A), file)

    with open(f'{os.path.dirname(__file__)}/fixtures/test_td_iteration_of_value_q_pi_planning.pickle', 'rb') as file:
        pi_fixture, q_S_A_fixture = pickle.load(file)

    assert tabular_pi_legacy_eq(mdp_agent.pi, pi_fixture) and tabular_estimator_legacy_eq(q_S_A, q_S_A_fixture)
def test_n_step_q_learning_iterate_value_q_pi():
    """
    Test n-step Q-learning iteration of the state-action value function and policy on the example 4.1 gridworld,
    comparing the resulting policy and estimator against a pickled fixture.
    """

    random_state = RandomState(12345)

    mdp_environment: Gridworld = Gridworld.example_4_1(random_state, None)

    q_S_A = TabularStateActionValueEstimator(mdp_environment, 0.05, None)

    mdp_agent = StochasticMdpAgent(
        'test',
        random_state,
        q_S_A.get_initial_policy(),
        1
    )

    iterate_value_q_pi(
        agent=mdp_agent,
        environment=mdp_environment,
        num_improvements=10,
        num_episodes_per_improvement=100,
        num_updates_per_improvement=None,
        alpha=0.1,
        mode=Mode.Q_LEARNING,
        n_steps=3,
        planning_environment=None,
        make_final_policy_greedy=False,
        q_S_A=q_S_A
    )

    # uncomment the following two lines and run the test to update the fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_td_n_step_q_learning_iteration_of_value_q_pi.pickle', 'wb') as file:
    #     pickle.dump((mdp_agent.pi, q_S_A), file)

    with open(f'{os.path.dirname(__file__)}/fixtures/test_td_n_step_q_learning_iteration_of_value_q_pi.pickle', 'rb') as file:
        fixture_pi, fixture_q_S_A = pickle.load(file)

    assert tabular_pi_legacy_eq(mdp_agent.pi, fixture_pi) and tabular_estimator_legacy_eq(q_S_A, fixture_q_S_A)
def test_learn():
    """
    Test Monte Carlo learning of a Mancala player, comparing the resulting policy against a pickled fixture and
    checking that resuming from a checkpoint produces the same policy as an uninterrupted run.
    """

    random_state = RandomState(12345)

    mancala: Mancala = Mancala(
        random_state=random_state,
        T=None,
        initial_count=4,
        player_2=None
    )

    p1 = ActionValueMdpAgent(
        'player 1',
        random_state,
        1,
        TabularStateActionValueEstimator(mancala, 0.05, None)
    )

    checkpoint_path = iterate_value_q_pi(
        agent=p1,
        environment=mancala,
        num_improvements=3,
        num_episodes_per_improvement=100,
        update_upon_every_visit=False,
        planning_environment=None,
        make_final_policy_greedy=False,
        num_improvements_per_checkpoint=3,
        checkpoint_path=tempfile.NamedTemporaryFile(delete=False).name
    )

    # uncomment the following two lines and run the test to update the fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_mancala.pickle', 'wb') as file:
    #     pickle.dump(p1.pi, file)

    with open(f'{os.path.dirname(__file__)}/fixtures/test_mancala.pickle', 'rb') as file:
        fixture = pickle.load(file)

    assert tabular_pi_legacy_eq(p1.pi, fixture)

    resumed_p1 = resume_from_checkpoint(
        checkpoint_path=checkpoint_path,
        resume_function=iterate_value_q_pi,
        num_improvements=2
    )

    # run the same total number of improvements without checkpointing...the result should be the same.
    random_state = RandomState(12345)

    mancala: Mancala = Mancala(
        random_state=random_state,
        T=None,
        initial_count=4,
        player_2=None
    )

    no_checkpoint_p1 = ActionValueMdpAgent(
        'player 1',
        random_state,
        1,
        TabularStateActionValueEstimator(mancala, 0.05, None)
    )

    iterate_value_q_pi(
        agent=no_checkpoint_p1,
        environment=mancala,
        num_improvements=5,
        num_episodes_per_improvement=100,
        update_upon_every_visit=False,
        planning_environment=None,
        make_final_policy_greedy=False
    )

    assert no_checkpoint_p1.pi == resumed_p1.pi