def test_policy_iteration():
    """Check that policy iteration recovers the optimal value function and policy on EasyGridWorld."""
    environment = EasyGridWorld()
    GAMMA = 0.9
    EPS = 1e-3
    policy, value_function = policy_iteration(environment, GAMMA, eps=EPS)
    torch.testing.assert_allclose(
        value_function.table, torch.tensor([OPTIMAL_VALUE]), atol=0.05, rtol=EPS
    )
    pred_p = policy.table.argmax(dim=0)
    assert_policy_equality(environment, GAMMA, value_function, OPTIMAL_POLICY, pred_p)

    # Repeat with a terminal state in the grid.
    environment = EasyGridWorld(terminal_states=[22])
    GAMMA = 0.9
    EPS = 1e-3
    policy, value_function = policy_iteration(environment, GAMMA, eps=EPS)
    torch.testing.assert_allclose(
        value_function.table,
        torch.tensor([OPTIMAL_VALUE_WITH_TERMINAL]),
        atol=0.05,
        rtol=EPS,
    )
    pred_p = policy.table.argmax(dim=0)
    assert_policy_equality(
        environment, GAMMA, value_function, OPTIMAL_POLICY_WITH_TERMINAL, pred_p
    )
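# For reference, a minimal sketch of the policy iteration loop that the
# `policy_iteration` routine under test is named after. It assumes hypothetical
# dense tensors P[s, a, s'] (transition probabilities) and R[s, a] (expected
# rewards); these are illustrative inputs, not the EasyGridWorld API used above.
import torch


def policy_iteration_sketch(P, R, gamma, eps=1e-3, max_iter=1_000):
    num_states, num_actions, _ = P.shape
    policy = torch.zeros(num_states, dtype=torch.long)  # Arbitrary initial deterministic policy.
    value = torch.zeros(num_states)
    for _ in range(max_iter):
        # Policy evaluation: apply the Bellman expectation backup until it (nearly) converges.
        while True:
            q = R + gamma * (P @ value)  # q[s, a] = R[s, a] + gamma * sum_s' P[s, a, s'] * v[s']
            new_value = q[torch.arange(num_states), policy]
            converged = (new_value - value).abs().max() < eps
            value = new_value
            if converged:
                break
        # Policy improvement: act greedily with respect to the evaluated Q-values.
        new_policy = q.argmax(dim=-1)
        if torch.equal(new_policy, policy):  # A stable greedy policy is optimal for this MDP.
            break
        policy = new_policy
    return policy, value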
def test_tabular_interaction(agent, policy):
    """Check that a tabular agent trains and evaluates end to end on EasyGridWorld."""
    LEARNING_RATE = 0.1
    environment = EasyGridWorld()
    critic = TabularQFunction(
        num_states=environment.num_states, num_actions=environment.num_actions
    )
    policy = policy(critic, 0.1)
    optimizer = torch.optim.Adam(critic.parameters(), lr=LEARNING_RATE)
    criterion = torch.nn.MSELoss

    agent = agent(
        critic=critic,
        policy=policy,
        criterion=criterion,
        optimizer=optimizer,
        target_update_frequency=TARGET_UPDATE_FREQUENCY,
        gamma=GAMMA,
    )
    train_agent(
        agent,
        environment,
        num_episodes=NUM_EPISODES,
        max_steps=MAX_STEPS,
        plot_flag=False,
    )
    evaluate_agent(agent, environment, 1, MAX_STEPS, render=False)
    agent.logger.delete_directory()  # Cleanup directory.
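# A hedged sketch of the kind of temporal-difference step agents of this type
# perform on the tabular critic. `q_table` and `target_table` stand in for the
# critic's [num_actions, num_states] table and its periodically synced target
# copy; the helper and its argument names are hypothetical, not the rllib agent's
# internals.
import torch


def td_update_sketch(q_table, target_table, optimizer, gamma, state, action, reward, next_state, done):
    """One TD(0) update on a tabular Q-function stored as a [num_actions, num_states] parameter."""
    with torch.no_grad():
        # Bootstrap from the target table; zero the bootstrapped value at terminal states.
        next_value = torch.zeros(()) if done else target_table[:, next_state].max()
        target = reward + gamma * next_value
    prediction = q_table[action, state]
    loss = torch.nn.functional.mse_loss(prediction, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()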
def test_rollout_easy_grid_world():
    environment = EasyGridWorld()
    agent = RandomAgent.default(environment)
    rollout_agent(environment, agent, max_steps=20)

    policy = agent.policy
    rollout_policy(environment, policy)
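# A minimal sketch of what a rollout helper such as `rollout_agent` does, assuming
# a gym-style environment interface (`reset`/`step`) and an agent exposing `act`;
# these method names are assumptions for illustration, not the rllib signatures.
def rollout_sketch(environment, agent, max_steps=20):
    state = environment.reset()
    trajectory = []
    for _ in range(max_steps):
        action = agent.act(state)
        next_state, reward, done, info = environment.step(action)
        trajectory.append((state, action, reward, next_state, done))
        state = next_state
        if done:  # Stop early when the episode terminates.
            break
    return trajectory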
def test_linear_system_policy_evaluation():
    environment = EasyGridWorld()
    GAMMA = 0.9
    EPS = 1e-3
    policy = RandomPolicy(
        dim_state=(),
        dim_action=(),
        num_states=environment.num_states,
        num_actions=environment.num_actions,
    )
    value_function = linear_system_policy_evaluation(policy, environment, GAMMA)
    torch.testing.assert_allclose(
        value_function.table, torch.tensor([RANDOM_VALUE]), atol=0.05, rtol=EPS
    )
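# Sketch of the closed-form evaluation that `linear_system_policy_evaluation` is
# named after: solve (I - gamma * P_pi) v = r_pi. Here P_pi (state-to-state
# transition matrix under the policy) and r_pi (expected reward per state) are
# hypothetical dense inputs, not quantities exposed by the test above.
import torch


def linear_system_evaluation_sketch(P_pi, r_pi, gamma):
    num_states = P_pi.shape[0]
    identity = torch.eye(num_states)
    # v = (I - gamma * P_pi)^{-1} r_pi is the fixed point of the Bellman expectation operator.
    return torch.linalg.solve(identity - gamma * P_pi, r_pi)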
def test_tabular_interaction(agent, policy):
    """Check that a replay-based tabular agent trains on EasyGridWorld and keeps the expected Q-table shape."""
    LEARNING_RATE = 0.1
    environment = EasyGridWorld()
    critic = TabularQFunction(
        num_states=environment.num_states, num_actions=environment.num_actions
    )
    policy = policy(critic, 0.1)
    optimizer = torch.optim.Adam(critic.parameters(), lr=LEARNING_RATE)
    criterion = torch.nn.MSELoss
    memory = ExperienceReplay(max_len=MEMORY_MAX_SIZE)

    agent = agent(
        critic=critic,
        policy=policy,
        criterion=criterion,
        optimizer=optimizer,
        memory=memory,
        batch_size=BATCH_SIZE,
        target_update_frequency=TARGET_UPDATE_FREQUENCY,
        gamma=GAMMA,
    )
    train_agent(
        agent,
        environment,
        num_episodes=NUM_EPISODES,
        max_steps=MAX_STEPS,
        plot_flag=False,
    )
    evaluate_agent(agent, environment, 1, MAX_STEPS, render=False)
    agent.logger.delete_directory()  # Cleanup directory.

    torch.testing.assert_allclose(
        critic.table.shape,
        torch.Size([environment.num_actions, environment.num_states]),
    )
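# A minimal sketch of the experience-replay idea used above: store transitions and
# sample random mini-batches for the critic update. It mirrors the role of rllib's
# ExperienceReplay but is not its implementation; the class and method names here
# are assumptions for illustration.
import random
from collections import deque


class ReplayBufferSketch:
    def __init__(self, max_len):
        self.buffer = deque(maxlen=max_len)  # Oldest transitions are evicted first.

    def append(self, transition):
        self.buffer.append(transition)

    def sample(self, batch_size):
        # Uniform sampling breaks the temporal correlation between consecutive transitions.
        return random.sample(self.buffer, min(batch_size, len(self.buffer)))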
"""Tabular planning experiments.""" from rllib.algorithms.tabular_planning import ( iterative_policy_evaluation, linear_system_policy_evaluation, policy_iteration, value_iteration, ) from rllib.environment.mdps import EasyGridWorld from rllib.policy import AbstractPolicy, RandomPolicy environment = EasyGridWorld() GAMMA = 0.9 EPS = 1e-6 policy = RandomPolicy( dim_state=(), dim_action=(), num_states=environment.num_states, num_actions=environment.num_actions, ) # type: AbstractPolicy print("Iterative Policy Evaluation:") value_function = iterative_policy_evaluation(policy, environment, GAMMA, eps=EPS) print(value_function.table) print() print("Linear System Policy Evaluation:") value_function = linear_system_policy_evaluation(policy, environment, GAMMA)