Example #1
def test_tabular_interaction(agent, policy):
    LEARNING_RATE = 0.1
    environment = EasyGridWorld()

    critic = TabularQFunction(num_states=environment.num_states,
                              num_actions=environment.num_actions)
    policy = policy(critic, 0.1)
    optimizer = torch.optim.Adam(critic.parameters(), lr=LEARNING_RATE)
    criterion = torch.nn.MSELoss  # Loss class, not an instance; assumed to be instantiated inside the agent.

    agent = agent(
        critic=critic,
        policy=policy,
        criterion=criterion,
        optimizer=optimizer,
        target_update_frequency=TARGET_UPDATE_FREQUENCY,
        gamma=GAMMA,
    )

    train_agent(
        agent,
        environment,
        num_episodes=NUM_EPISODES,
        max_steps=MAX_STEPS,
        plot_flag=False,
    )
    evaluate_agent(agent, environment, 1, MAX_STEPS, render=False)
    agent.logger.delete_directory()  # Cleanup directory.
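Example #1 (and Example #5 below) relies on module-level constants that are not shown in the snippet. A minimal sketch with assumed values, just so the snippets can run standalone; the real test module defines its own:

# Assumed values; the actual constants live elsewhere in the test module.
GAMMA = 0.99                 # Discount factor (assumed).
TARGET_UPDATE_FREQUENCY = 4  # Steps between target-network updates (assumed).
NUM_EPISODES = 10            # Number of training episodes (assumed).
MAX_STEPS = 200              # Per-episode step cap (assumed).
MEMORY_MAX_SIZE = 5000       # Replay-buffer capacity, Example #5 (assumed).
BATCH_SIZE = 32              # Replay mini-batch size, Example #5 (assumed).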
Example #2
    def test_forward(self, num_states, num_actions, batch_size):
        q_function = TabularQFunction(num_states=num_states, num_actions=num_actions)

        state = random_tensor(True, num_states, batch_size)
        action = random_tensor(True, num_actions, batch_size)
        value = q_function(state, action)
        assert value.shape == torch.Size([batch_size] if batch_size else [])
        assert value.dtype is torch.get_default_dtype()
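The random_tensor helper used by these shape tests is not shown. From its call sites — a discrete flag, a dimension, and an optional batch size — here is a hypothetical sketch that is consistent with the shape and dtype asserts above:

import torch

def random_tensor(discrete, dim, batch_size):
    """Hypothetical helper: sample a random state/action tensor.

    For discrete spaces, returns integer indices in [0, dim); a falsy
    batch_size yields an unbatched 0-dim tensor, matching the
    `[batch_size] if batch_size else []` shape asserts in the tests.
    """
    if discrete:
        return torch.randint(dim, (batch_size,) if batch_size else ())
    return torch.randn(batch_size, dim) if batch_size else torch.randn(dim)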
Example #3
    def test_partial_q_function(self, num_states, num_actions, batch_size):
        q_function = TabularQFunction(num_states=num_states, num_actions=num_actions)
        state = random_tensor(True, num_states, batch_size)

        action_value = q_function(state)
        assert action_value.shape == torch.Size(
            [batch_size, num_actions] if batch_size else [num_actions]
        )
        assert action_value.dtype is torch.get_default_dtype()
Example #4
def get_default_q_function(environment, function_approximation):
    """Get default Q-Function."""
    if function_approximation == "tabular":
        q_function = TabularQFunction.default(environment)
    elif function_approximation == "linear":
        q_function = NNQFunction.default(environment, layers=[200])
        freeze_hidden_layers(q_function)
    else:
        q_function = NNQFunction.default(environment)
    return q_function
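freeze_hidden_layers is what makes the "linear" variant linear: with the hidden weights frozen, only the output head of the [200]-unit network keeps training, so the Q-function is a linear map of fixed random features. A hypothetical sketch of such a helper, assuming the output head's weight and bias are the last two entries in the module's parameter list:

def freeze_hidden_layers(module):
    """Hypothetical sketch: stop gradients for every non-final layer.

    Assumes the output head's parameters (weight and bias) are the last
    two entries returned by module.parameters().
    """
    params = list(module.parameters())
    for param in params[:-2]:  # Everything except the final weight/bias.
        param.requires_grad_(False)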
Example #5
def test_tabular_interaction(agent, policy):
    LEARNING_RATE = 0.1
    environment = EasyGridWorld()

    critic = TabularQFunction(num_states=environment.num_states,
                              num_actions=environment.num_actions)
    policy = policy(critic, 0.1)
    optimizer = torch.optim.Adam(critic.parameters(), lr=LEARNING_RATE)
    criterion = torch.nn.MSELoss  # Loss class, not an instance; assumed to be instantiated inside the agent.
    memory = ExperienceReplay(max_len=MEMORY_MAX_SIZE)

    agent = agent(
        critic=critic,
        policy=policy,
        criterion=criterion,
        optimizer=optimizer,
        memory=memory,
        batch_size=BATCH_SIZE,
        target_update_frequency=TARGET_UPDATE_FREQUENCY,
        gamma=GAMMA,
    )

    train_agent(
        agent,
        environment,
        num_episodes=NUM_EPISODES,
        max_steps=MAX_STEPS,
        plot_flag=False,
    )
    evaluate_agent(agent, environment, 1, MAX_STEPS, render=False)
    agent.logger.delete_directory()  # Cleanup directory.

    assert critic.table.shape == torch.Size(
        [environment.num_actions, environment.num_states]
    )
Example #6
    def test_set_value(self):
        value_function = TabularQFunction(num_states=4, num_actions=2)
        # Set Q(state=2, action=1) = 1.0; the table is indexed [action, state].
        value_function.set_value(2, 1, 1.0)
        torch.testing.assert_allclose(
            value_function.table, torch.tensor([[0, 0, 0.0, 0], [0, 0, 1.0, 0]])
        )
Example #7
    def test_compile(self):
        torch.jit.script(TabularQFunction(num_states=4, num_actions=2))
Example #8
    def test_init(self):
        value_function = TabularQFunction(num_states=4, num_actions=2)
        # The table starts at zero with shape [num_actions, num_states].
        torch.testing.assert_allclose(value_function.table, torch.zeros(2, 4))
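Taken together, the tests pin down TabularQFunction's contract: a zero-initialized table of shape [num_actions, num_states] (Examples #5 and #8), a partial call q(state) returning per-action values (Example #3), a full call q(state, action) returning a scalar per sample (Example #2), a set_value mutator (Example #6), and TorchScript compatibility (Example #7). A minimal sketch meeting that contract; the real class presumably also provides the .default(environment) constructor used in Example #4 plus input validation:

from typing import Optional

import torch
import torch.nn as nn


class TabularQFunction(nn.Module):
    """Hypothetical sketch of a tabular Q-function matching the tests above."""

    def __init__(self, num_states: int, num_actions: int):
        super().__init__()
        # Zero-initialized table indexed [action, state] (Examples #5, #8).
        self.table = nn.Parameter(torch.zeros(num_actions, num_states))

    def forward(
        self, state: torch.Tensor, action: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # Values of all actions at the given state(s): shape [..., num_actions].
        action_values = self.table.t()[state.long()]
        if action is None:
            return action_values  # Partial call (Example #3).
        # Value of the taken action only; one scalar per sample (Example #2).
        return action_values.gather(-1, action.long().unsqueeze(-1)).squeeze(-1)

    @torch.jit.export
    def set_value(self, state: int, action: int, value: float):
        # In-place table update (Example #6).
        with torch.no_grad():
            self.table[action, state] = value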