def get_dqn(
    env,
    lr,
    optimizer_,
    num_rollouts,
    num_iter,
    tau,
    target_update_frequency,
    function_approximation="tabular",
    *args,
    **kwargs,
):
    """Get DQN agent."""
    critic = get_default_q_function(env, function_approximation)
    critic.tau = tau  # Soft target-update coefficient.
    # Epsilon-greedy exploration with exponentially decaying epsilon (from 0.1 towards 0).
    policy = EpsGreedy(critic, ExponentialDecay(0.1, 0, 1000))
    optimizer = optimizer_(critic.parameters(), lr=lr)
    memory = ExperienceReplay(max_len=50000, num_steps=0)
    return DQNAgent(
        critic=critic,
        policy=policy,
        optimizer=optimizer,
        memory=memory,
        reset_memory_after_learn=True,
        train_frequency=0,
        num_iter=num_iter,
        target_update_frequency=target_update_frequency,
        num_rollouts=num_rollouts,
        *args,
        **kwargs,
    )
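# Hypothetical usage sketch (not part of the original file): builds a DQN agent
# with the factory above. "CartPole-v0", torch.optim.Adam, the rllib-style
# import path, and all hyperparameter values are illustrative assumptions.
def _example_get_dqn():
    import torch

    from rllib.environment import GymEnvironment

    env = GymEnvironment("CartPole-v0")
    return get_dqn(
        env,
        lr=1e-3,
        optimizer_=torch.optim.Adam,  # Pass the optimizer class, not an instance.
        num_rollouts=1,
        num_iter=10,
        tau=0.1,
        target_update_frequency=4,
    )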
def create_er_from_transitions(
    discrete, dim_state, dim_action, max_len, num_steps, num_transitions
):
    """Create a memory with `num_transitions` transitions."""
    if discrete:
        # For discrete spaces, dim_state/dim_action are the numbers of states/actions.
        num_states, num_actions = dim_state, dim_action
        dim_state, dim_action = (), ()
    else:
        # For continuous spaces, -1 marks the space as non-discrete.
        num_states, num_actions = -1, -1
        dim_state, dim_action = (dim_state,), (dim_action,)

    memory = ExperienceReplay(max_len, num_steps=num_steps)
    for _ in range(num_transitions):
        observation = Observation.random_example(
            dim_state=dim_state,
            dim_action=dim_action,
            num_states=num_states,
            num_actions=num_actions,
        )
        memory.append(observation)
    return memory
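# Hypothetical smoke test (not in the original file) exercising both branches
# of the fixture above; all sizes are arbitrary illustrative choices.
def _example_create_er_from_transitions():
    discrete_memory = create_er_from_transitions(
        discrete=True, dim_state=4, dim_action=2,
        max_len=100, num_steps=0, num_transitions=20,
    )
    continuous_memory = create_er_from_transitions(
        discrete=False, dim_state=3, dim_action=1,
        max_len=100, num_steps=0, num_transitions=20,
    )
    return discrete_memory, continuous_memory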
def create_er_from_episodes(
    discrete, max_len, num_steps, num_episodes, episode_length
):
    """Roll out an environment and return an experience replay buffer."""
    if discrete:
        env = GymEnvironment("NChain-v0")
        transformations = []
    else:
        env = GymEnvironment("Pendulum-v0")
        transformations = [
            MeanFunction(lambda state_, action_: state_),
            StateNormalizer(),
            ActionNormalizer(),
            RewardClipper(),
        ]

    memory = ExperienceReplay(
        max_len, transformations=transformations, num_steps=num_steps
    )
    for _ in range(num_episodes):
        state = env.reset()
        for _ in range(episode_length):
            action = env.action_space.sample()  # Sample a random action.
            observation, state, done, info = step_env(
                env, state, action, action_scale=1.0
            )
            memory.append(observation)
        memory.end_episode()
    return memory
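# Hypothetical usage sketch (not in the original file): fill a buffer with two
# random-action Pendulum episodes of 50 steps each; the numbers are arbitrary.
def _example_create_er_from_episodes():
    return create_er_from_episodes(
        discrete=False, max_len=1000, num_steps=0, num_episodes=2, episode_length=50
    )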
def get_vmpo(
    env, optimizer_, lr, function_approximation="tabular", *args, **kwargs
):
    """Get VMPO agent."""
    critic = get_default_value_function(env, function_approximation)
    policy = get_default_policy(env, function_approximation)
    memory = ExperienceReplay(max_len=50000, num_steps=0)
    # Jointly optimize critic and policy parameters.
    optimizer = optimizer_(chain(critic.parameters(), policy.parameters()), lr=lr)
    return VMPOAgent(
        policy=policy,
        critic=critic,
        optimizer=optimizer,
        memory=memory,
        *args,
        **kwargs,
    )
def get_reps_parametric(
    env, optimizer_, lr, function_approximation="tabular", *args, **kwargs
):
    """Get Parametric REPS agent."""
    critic = get_default_value_function(env, function_approximation)
    policy = get_default_policy(env, function_approximation)
    optimizer = optimizer_(chain(critic.parameters(), policy.parameters()), lr=lr)
    memory = ExperienceReplay(max_len=50000, num_steps=0)
    return REPSAgent(
        critic=critic,
        policy=policy,
        optimizer=optimizer,
        memory=memory,
        *args,
        **kwargs,
    )
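# Hypothetical usage sketch (not in the original file): the two factories above
# share the same call signature, so they can be swapped freely. torch.optim.Adam
# and the learning rate are illustrative assumptions.
def _example_get_actor_critic_agents(env):
    import torch

    vmpo_agent = get_vmpo(env, optimizer_=torch.optim.Adam, lr=3e-4)
    reps_agent = get_reps_parametric(env, optimizer_=torch.optim.Adam, lr=3e-4)
    return vmpo_agent, reps_agent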
def test_policies(environment, policy):
    """Train and evaluate a DDQN agent with the given policy class."""
    environment = GymEnvironment(environment, SEED)
    critic = NNQFunction(
        dim_state=environment.dim_observation,
        dim_action=environment.dim_action,
        num_states=environment.num_states,
        num_actions=environment.num_actions,
        layers=LAYERS,
        tau=TARGET_UPDATE_TAU,
    )
    policy = policy(critic, 0.1)
    optimizer = torch.optim.Adam(critic.parameters(), lr=LEARNING_RATE)
    criterion = torch.nn.MSELoss
    memory = ExperienceReplay(max_len=MEMORY_MAX_SIZE)
    agent = DDQNAgent(
        critic=critic,
        policy=policy,
        criterion=criterion,
        optimizer=optimizer,
        memory=memory,
        batch_size=BATCH_SIZE,
        target_update_frequency=TARGET_UPDATE_FREQUENCY,
        gamma=GAMMA,
    )
    train_agent(
        agent,
        environment,
        num_episodes=NUM_EPISODES,
        max_steps=MAX_STEPS,
        plot_flag=False,
    )
    evaluate_agent(agent, environment, 1, MAX_STEPS, render=False)
    agent.logger.delete_directory()  # Cleanup directory.
def test_tabular_interaction(agent, policy):
    """Train a tabular agent on a grid world and check the Q-table shape."""
    LEARNING_RATE = 0.1
    environment = EasyGridWorld()
    critic = TabularQFunction(
        num_states=environment.num_states, num_actions=environment.num_actions
    )
    policy = policy(critic, 0.1)
    optimizer = torch.optim.Adam(critic.parameters(), lr=LEARNING_RATE)
    criterion = torch.nn.MSELoss
    memory = ExperienceReplay(max_len=MEMORY_MAX_SIZE)
    agent = agent(
        critic=critic,
        policy=policy,
        criterion=criterion,
        optimizer=optimizer,
        memory=memory,
        batch_size=BATCH_SIZE,
        target_update_frequency=TARGET_UPDATE_FREQUENCY,
        gamma=GAMMA,
    )
    train_agent(
        agent,
        environment,
        num_episodes=NUM_EPISODES,
        max_steps=MAX_STEPS,
        plot_flag=False,
    )
    evaluate_agent(agent, environment, 1, MAX_STEPS, render=False)
    agent.logger.delete_directory()  # Cleanup directory.

    torch.testing.assert_allclose(
        critic.table.shape,
        torch.Size([environment.num_actions, environment.num_states]),
    )
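# Hypothetical pytest fixtures driving the two tests above; the original file
# presumably defines its own parametrization. The environment names, the policy
# classes (EpsGreedy/SoftMax, taking a Q-function and an exploration parameter),
# and the agent classes are all illustrative assumptions.
@pytest.fixture(params=["CartPole-v0", "NChain-v0"])
def environment(request):
    return request.param


@pytest.fixture(params=[EpsGreedy, SoftMax])
def policy(request):
    return request.param


@pytest.fixture(params=[DQNAgent, DDQNAgent])
def agent(request):
    return request.param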
def test_append_error(self):
    """Appending a raw tuple instead of an Observation raises a TypeError."""
    memory = ExperienceReplay(max_len=100)
    with pytest.raises(TypeError):
        memory.append((1, 2, 3, 4, 5))
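# Hypothetical companion test (not in the original file): appending a proper
# Observation succeeds. Observation.random_example is called with the same
# continuous-space arguments used in create_er_from_transitions above.
def test_append_observation(self):
    """Appending a well-formed Observation does not raise."""
    memory = ExperienceReplay(max_len=100)
    observation = Observation.random_example(
        dim_state=(3,), dim_action=(1,), num_states=-1, num_actions=-1
    )
    memory.append(observation)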
policy = EpsGreedy(q_function, ExponentialDecay(EPS_START, EPS_END, EPS_DECAY))
optimizer = torch.optim.Adam(
    q_function.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)
criterion = torch.nn.MSELoss

if MEMORY == "PER":
    # Prioritized replay; the importance-sampling exponent beta grows linearly
    # from 0.8 to 1.0.
    memory = PrioritizedExperienceReplay(
        max_len=MEMORY_MAX_SIZE, beta=LinearGrowth(0.8, 1.0, 0.001)
    )
elif MEMORY == "EXP3":
    memory = EXP3ExperienceReplay(max_len=MEMORY_MAX_SIZE, alpha=0.001, beta=0.1)
elif MEMORY == "ER":
    memory = ExperienceReplay(max_len=MEMORY_MAX_SIZE, num_steps=0)
else:
    raise NotImplementedError(f"{MEMORY} not implemented.")

agent = DDQNAgent(
    critic=q_function,
    policy=policy,
    criterion=criterion,
    optimizer=optimizer,
    memory=memory,
    num_iter=1,
    train_frequency=1,
    batch_size=BATCH_SIZE,
    target_update_frequency=TARGET_UPDATE_FREQUENCY,
    gamma=GAMMA,
    clip_gradient_val=1.0,
)
policy = Policy(q_function, ExponentialDecay(start=1.0, end=0.01, decay=500))
q_target = NNQFunction(
    dim_state=environment.dim_observation,
    dim_action=environment.dim_action,
    num_states=environment.num_states,
    num_actions=environment.num_actions,
    layers=LAYERS,
    tau=TARGET_UPDATE_TAU,
)
optimizer = torch.optim.Adam(q_function.parameters(), lr=LEARNING_RATE)
criterion = torch.nn.MSELoss
memory = ExperienceReplay(max_len=MEMORY_MAX_SIZE)

agent = DDQNAgent(
    critic=q_function,
    policy=policy,
    criterion=criterion,
    optimizer=optimizer,
    memory=memory,
    batch_size=BATCH_SIZE,
    target_update_frequency=TARGET_UPDATE_FREQUENCY,
    gamma=GAMMA,
)
rollout_agent(environment, agent, num_episodes=NUM_EPISODES, max_steps=MAX_STEPS)