def test_tabular_interaction(agent, policy):
    LEARNING_RATE = 0.1
    environment = EasyGridWorld()
    critic = TabularQFunction(
        num_states=environment.num_states, num_actions=environment.num_actions
    )
    policy = policy(critic, 0.1)
    optimizer = torch.optim.Adam(critic.parameters(), lr=LEARNING_RATE)
    criterion = torch.nn.MSELoss
    agent = agent(
        critic=critic,
        policy=policy,
        criterion=criterion,
        optimizer=optimizer,
        target_update_frequency=TARGET_UPDATE_FREQUENCY,
        gamma=GAMMA,
    )

    train_agent(
        agent,
        environment,
        num_episodes=NUM_EPISODES,
        max_steps=MAX_STEPS,
        plot_flag=False,
    )
    evaluate_agent(agent, environment, 1, MAX_STEPS, render=False)
    agent.logger.delete_directory()  # Cleanup directory.
def test_forward(self, num_states, num_actions, batch_size):
    q_function = TabularQFunction(num_states=num_states, num_actions=num_actions)
    state = random_tensor(True, num_states, batch_size)
    action = random_tensor(True, num_actions, batch_size)
    value = q_function(state, action)
    assert value.shape == torch.Size([batch_size] if batch_size else [])
    assert value.dtype is torch.get_default_dtype()
def test_partial_q_function(self, num_states, num_actions, batch_size):
    q_function = TabularQFunction(num_states=num_states, num_actions=num_actions)
    state = random_tensor(True, num_states, batch_size)
    action_value = q_function(state)
    assert action_value.shape == torch.Size(
        [batch_size, num_actions] if batch_size else [num_actions]
    )
    assert action_value.dtype is torch.get_default_dtype()
def get_default_q_function(environment, function_approximation):
    """Get default Q-Function."""
    if function_approximation == "tabular":
        q_function = TabularQFunction.default(environment)
    elif function_approximation == "linear":
        q_function = NNQFunction.default(environment, layers=[200])
        freeze_hidden_layers(q_function)
    else:
        q_function = NNQFunction.default(environment)
    return q_function
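# Hedged usage sketch (not from the source): how the helper above might be called
# in a test setup. EasyGridWorld is taken from the surrounding tests; the "deep"
# label is an assumption, since any string other than "tabular" or "linear" falls
# through to the default NNQFunction branch.
#
#     environment = EasyGridWorld()
#     q_tabular = get_default_q_function(environment, "tabular")  # exact table over states/actions
#     q_linear = get_default_q_function(environment, "linear")    # single 200-unit layer, hidden weights frozen
#     q_deep = get_default_q_function(environment, "deep")        # default NNQFunction architecture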
def test_tabular_interaction(agent, policy):
    LEARNING_RATE = 0.1
    environment = EasyGridWorld()
    critic = TabularQFunction(
        num_states=environment.num_states, num_actions=environment.num_actions
    )
    policy = policy(critic, 0.1)
    optimizer = torch.optim.Adam(critic.parameters(), lr=LEARNING_RATE)
    criterion = torch.nn.MSELoss
    memory = ExperienceReplay(max_len=MEMORY_MAX_SIZE)
    agent = agent(
        critic=critic,
        policy=policy,
        criterion=criterion,
        optimizer=optimizer,
        memory=memory,
        batch_size=BATCH_SIZE,
        target_update_frequency=TARGET_UPDATE_FREQUENCY,
        gamma=GAMMA,
    )

    train_agent(
        agent,
        environment,
        num_episodes=NUM_EPISODES,
        max_steps=MAX_STEPS,
        plot_flag=False,
    )
    evaluate_agent(agent, environment, 1, MAX_STEPS, render=False)
    agent.logger.delete_directory()  # Cleanup directory.

    torch.testing.assert_allclose(
        critic.table.shape,
        torch.Size([environment.num_actions, environment.num_states]),
    )
def test_set_value(self):
    value_function = TabularQFunction(num_states=4, num_actions=2)
    value_function.set_value(2, 1, 1.0)
    torch.testing.assert_allclose(
        value_function.table, torch.tensor([[0, 0, 0.0, 0], [0, 0, 1.0, 0]])
    )
def test_compile(self):
    torch.jit.script(TabularQFunction(num_states=4, num_actions=2))
def test_init(self):
    value_function = TabularQFunction(num_states=4, num_actions=2)
    torch.testing.assert_allclose(value_function.table, torch.zeros(2, 4))
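# Hedged illustration, not part of the test suite: the tests above imply that
# TabularQFunction stores values in a [num_actions, num_states] table (test_init,
# test_set_value, and the shape assertion in test_tabular_interaction). The sketch
# below reproduces that lookup convention with plain torch and one-hot encodings;
# the lookup helper is hypothetical and only mirrors what the tests check.
import torch


def lookup(table, state_one_hot, action_one_hot=None):
    """Read Q-values from a [num_actions, num_states] table.

    With an action, return the scalar Q(s, a); without one, return the vector
    Q(s, .) over all actions, as in test_partial_q_function.
    """
    action_values = table @ state_one_hot  # shape: [num_actions]
    if action_one_hot is None:
        return action_values
    return action_values @ action_one_hot  # scalar Q(s, a)


table = torch.zeros(2, 4)  # num_actions=2, num_states=4, as in test_init
table[1, 2] = 1.0  # the entry written by test_set_value
state = torch.eye(4)[2]  # one-hot encoding of state 2
action = torch.eye(2)[1]  # one-hot encoding of action 1
assert lookup(table, state, action) == 1.0
assert lookup(table, state).shape == torch.Size([2])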