Esempio n. 1
0
def test_policy_iteration():
    """Policy iteration on EasyGridWorld should recover the known optimum.

    Runs once on the plain grid world and once with an absorbing terminal
    state, checking both the value table and the greedy policy each time.
    """
    gamma, tol = 0.9, 1e-3

    world = EasyGridWorld()
    policy, value_function = policy_iteration(world, gamma, eps=tol)
    torch.testing.assert_allclose(
        value_function.table, torch.tensor([OPTIMAL_VALUE]), atol=0.05, rtol=tol
    )
    greedy_actions = policy.table.argmax(dim=0)
    assert_policy_equality(
        world, gamma, value_function, OPTIMAL_POLICY, greedy_actions
    )

    # Same check, but with state 22 made terminal.
    world = EasyGridWorld(terminal_states=[22])
    policy, value_function = policy_iteration(world, gamma, eps=tol)
    torch.testing.assert_allclose(
        value_function.table,
        torch.tensor([OPTIMAL_VALUE_WITH_TERMINAL]),
        atol=0.05,
        rtol=tol,
    )
    greedy_actions = policy.table.argmax(dim=0)
    assert_policy_equality(
        world, gamma, value_function, OPTIMAL_POLICY_WITH_TERMINAL, greedy_actions
    )
Esempio n. 2
0
def test_tabular_interaction(agent, policy):
    """Train and then evaluate a tabular agent/policy pair on EasyGridWorld.

    Args:
        agent: agent class, instantiated here with critic/policy/optimizer.
        policy: policy factory called with (critic, exploration parameter).
    """
    lr = 0.1
    world = EasyGridWorld()

    q_function = TabularQFunction(
        num_states=world.num_states, num_actions=world.num_actions
    )
    behavior_policy = policy(q_function, 0.1)
    adam = torch.optim.Adam(q_function.parameters(), lr=lr)
    # NOTE(review): the loss CLASS (not an instance) is handed to the agent;
    # presumably the agent instantiates it — confirm against the agent API.
    loss_cls = torch.nn.MSELoss

    learner = agent(
        critic=q_function,
        policy=behavior_policy,
        criterion=loss_cls,
        optimizer=adam,
        target_update_frequency=TARGET_UPDATE_FREQUENCY,
        gamma=GAMMA,
    )

    train_agent(
        learner,
        world,
        num_episodes=NUM_EPISODES,
        max_steps=MAX_STEPS,
        plot_flag=False,
    )
    evaluate_agent(learner, world, 1, MAX_STEPS, render=False)
    learner.logger.delete_directory()  # Cleanup directory.
Esempio n. 3
0
def test_rollout_easy_grid_world():
    """A default random agent and its policy should both roll out cleanly."""
    world = EasyGridWorld()
    random_agent = RandomAgent.default(world)
    rollout_agent(world, random_agent, max_steps=20)
    # Reuse the agent's own policy for a policy-level rollout.
    rollout_policy(world, random_agent.policy)
Esempio n. 4
0
def test_linear_system_policy_evaluation():
    """Direct linear-system evaluation should match the known random-policy values."""
    world = EasyGridWorld()
    gamma, tol = 0.9, 1e-3

    uniform_policy = RandomPolicy(
        dim_state=(),
        dim_action=(),
        num_states=world.num_states,
        num_actions=world.num_actions,
    )
    values = linear_system_policy_evaluation(uniform_policy, world, gamma)

    torch.testing.assert_allclose(
        values.table, torch.tensor([RANDOM_VALUE]), atol=0.05, rtol=tol
    )
def test_tabular_interaction(agent, policy):
    """Train and evaluate a replay-based tabular agent on EasyGridWorld.

    Args:
        agent: agent class, instantiated here with critic/policy/memory etc.
        policy: policy factory called with (critic, exploration parameter).
    """
    LEARNING_RATE = 0.1
    environment = EasyGridWorld()

    critic = TabularQFunction(num_states=environment.num_states,
                              num_actions=environment.num_actions)
    policy = policy(critic, 0.1)
    optimizer = torch.optim.Adam(critic.parameters(), lr=LEARNING_RATE)
    # NOTE(review): the loss CLASS (not an instance) is passed through;
    # presumably the agent instantiates it — confirm against the agent API.
    criterion = torch.nn.MSELoss
    memory = ExperienceReplay(max_len=MEMORY_MAX_SIZE)

    agent = agent(
        critic=critic,
        policy=policy,
        criterion=criterion,
        optimizer=optimizer,
        memory=memory,
        batch_size=BATCH_SIZE,
        target_update_frequency=TARGET_UPDATE_FREQUENCY,
        gamma=GAMMA,
    )

    train_agent(
        agent,
        environment,
        num_episodes=NUM_EPISODES,
        max_steps=MAX_STEPS,
        plot_flag=False,
    )
    evaluate_agent(agent, environment, 1, MAX_STEPS, render=False)
    agent.logger.delete_directory()  # Cleanup directory.

    # Shapes are integer tuples: compare them exactly instead of abusing the
    # float-tolerance helper `assert_allclose` on a pair of torch.Size values.
    assert critic.table.shape == torch.Size(
        [environment.num_actions, environment.num_states]
    )
Esempio n. 6
0
"""Tabular planning experiments."""
from rllib.algorithms.tabular_planning import (
    iterative_policy_evaluation,
    linear_system_policy_evaluation,
    policy_iteration,
    value_iteration,
)
from rllib.environment.mdps import EasyGridWorld
from rllib.policy import AbstractPolicy, RandomPolicy

# Shared experiment setup: a small grid-world MDP and a uniform random policy.
environment = EasyGridWorld()
GAMMA = 0.9  # Discount factor used by every evaluation below.
EPS = 1e-6  # Convergence tolerance for the iterative method.

# Uniform random policy over the tabular state/action space.
policy = RandomPolicy(
    dim_state=(),
    dim_action=(),
    num_states=environment.num_states,
    num_actions=environment.num_actions,
)  # type: AbstractPolicy

# Evaluate the random policy by iterating the Bellman expectation update
# until the value change falls below EPS.
print("Iterative Policy Evaluation:")
value_function = iterative_policy_evaluation(policy,
                                             environment,
                                             GAMMA,
                                             eps=EPS)
print(value_function.table)
print()

# Evaluate the same policy by solving the Bellman equations directly as a
# linear system; should agree with the iterative result above.
print("Linear System Policy Evaluation:")
value_function = linear_system_policy_evaluation(policy, environment, GAMMA)