Example #1
def get_dqn(
    env,
    lr,
    optimizer_,
    num_rollouts,
    num_iter,
    tau,
    target_update_frequency,
    function_approximation="tabular",
    *args,
    **kwargs,
):
    """Get DQN agent."""
    critic = get_default_q_function(env, function_approximation)
    critic.tau = tau
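    # Epsilon-greedy exploration with epsilon decaying exponentially from 0.1 towards 0.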
    policy = EpsGreedy(critic, ExponentialDecay(0.1, 0, 1000))
    optimizer = optimizer_(critic.parameters(), lr=lr)
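    # Replay buffer holding up to 50,000 transitions.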
    memory = ExperienceReplay(max_len=50000, num_steps=0)
    return DQNAgent(
        critic=critic,
        policy=policy,
        optimizer=optimizer,
        memory=memory,
        reset_memory_after_learn=True,
        train_frequency=0,
        num_iter=num_iter,
        target_update_frequency=target_update_frequency,
        num_rollouts=num_rollouts,
        *args,
        **kwargs,
    )
Example #2
def create_er_from_transitions(discrete, dim_state, dim_action, max_len,
                               num_steps, num_transitions):
    """Create a memory with `num_transitions' transitions."""
    if discrete:
        num_states, num_actions = dim_state, dim_action
        dim_state, dim_action = (), ()
    else:
        num_states, num_actions = -1, -1
        dim_state, dim_action = (dim_state, ), (dim_action, )

    memory = ExperienceReplay(max_len, num_steps=num_steps)
    for _ in range(num_transitions):
        observation = Observation.random_example(
            dim_state=dim_state,
            dim_action=dim_action,
            num_states=num_states,
            num_actions=num_actions,
        )
        memory.append(observation)
    return memory
Example #3
def create_er_from_episodes(discrete, max_len, num_steps, num_episodes,
                            episode_length):
    """Rollout an environment and return an Experience Replay Buffer."""

    if discrete:
        env = GymEnvironment("NChain-v0")
        transformations = []
    else:
        env = GymEnvironment("Pendulum-v0")
        transformations = [
            MeanFunction(lambda state_, action_: state_),
            StateNormalizer(),
            ActionNormalizer(),
            RewardClipper(),
        ]

    memory = ExperienceReplay(max_len,
                              transformations=transformations,
                              num_steps=num_steps)

    for _ in range(num_episodes):
        state = env.reset()
        for _ in range(episode_length):
            action = env.action_space.sample()  # sample a random action.
            observation, state, done, info = step_env(env,
                                                      state,
                                                      action,
                                                      action_scale=1.0)
            memory.append(observation)
        memory.end_episode()

    return memory
Example #4
def get_vmpo(env,
             optimizer_,
             lr,
             function_approximation="tabular",
             *args,
             **kwargs):
    """Get VMPO agent."""
    critic = get_default_value_function(env, function_approximation)
    policy = get_default_policy(env, function_approximation)
    memory = ExperienceReplay(max_len=50000, num_steps=0)
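    # A single optimizer updates both the critic and the policy parameters.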
    optimizer = optimizer_(chain(critic.parameters(), policy.parameters()), lr)
    return VMPOAgent(
        policy=policy,
        critic=critic,
        optimizer=optimizer,
        memory=memory,
        *args,
        **kwargs,
    )
Example #5
def get_reps_parametric(env,
                        optimizer_,
                        lr,
                        function_approximation="tabular",
                        *args,
                        **kwargs):
    """Get Parametric REPS agent."""
    critic = get_default_value_function(env, function_approximation)
    policy = get_default_policy(env, function_approximation)
    optimizer = optimizer_(chain(critic.parameters(), policy.parameters()),
                           lr=lr)
    memory = ExperienceReplay(max_len=50000, num_steps=0)
    return REPSAgent(
        critic=critic,
        policy=policy,
        optimizer=optimizer,
        memory=memory,
        *args,
        **kwargs,
    )
Example #6
def test_policies(environment, policy):
    environment = GymEnvironment(environment, SEED)

    critic = NNQFunction(
        dim_state=environment.dim_observation,
        dim_action=environment.dim_action,
        num_states=environment.num_states,
        num_actions=environment.num_actions,
        layers=LAYERS,
        tau=TARGET_UPDATE_TAU,
    )

    # `policy` is passed in as a class and instantiated here; 0.1 is its exploration parameter.
    policy = policy(critic, 0.1)

    optimizer = torch.optim.Adam(critic.parameters(), lr=LEARNING_RATE)
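    # Note: the loss is passed as a class, not an instance.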
    criterion = torch.nn.MSELoss
    memory = ExperienceReplay(max_len=MEMORY_MAX_SIZE)

    agent = DDQNAgent(
        critic=critic,
        policy=policy,
        criterion=criterion,
        optimizer=optimizer,
        memory=memory,
        batch_size=BATCH_SIZE,
        target_update_frequency=TARGET_UPDATE_FREQUENCY,
        gamma=GAMMA,
    )
    train_agent(
        agent,
        environment,
        num_episodes=NUM_EPISODES,
        max_steps=MAX_STEPS,
        plot_flag=False,
    )
    evaluate_agent(agent, environment, 1, MAX_STEPS, render=False)
    agent.logger.delete_directory()  # Cleanup directory.
Example #7
def test_tabular_interaction(agent, policy):
    LEARNING_RATE = 0.1
    environment = EasyGridWorld()

    # Tabular critic: one Q-value entry per (state, action) pair of the grid world.
    critic = TabularQFunction(num_states=environment.num_states,
                              num_actions=environment.num_actions)
    policy = policy(critic, 0.1)
    optimizer = torch.optim.Adam(critic.parameters(), lr=LEARNING_RATE)
    criterion = torch.nn.MSELoss
    memory = ExperienceReplay(max_len=MEMORY_MAX_SIZE)

    agent = agent(
        critic=critic,
        policy=policy,
        criterion=criterion,
        optimizer=optimizer,
        memory=memory,
        batch_size=BATCH_SIZE,
        target_update_frequency=TARGET_UPDATE_FREQUENCY,
        gamma=GAMMA,
    )

    train_agent(
        agent,
        environment,
        num_episodes=NUM_EPISODES,
        max_steps=MAX_STEPS,
        plot_flag=False,
    )
    evaluate_agent(agent, environment, 1, MAX_STEPS, render=False)
    agent.logger.delete_directory()  # Cleanup directory.

    torch.testing.assert_allclose(
        critic.table.shape,
        torch.Size([environment.num_actions, environment.num_states]),
    )
Example #8
def test_append_error(self):
    """Appending a plain tuple instead of an Observation raises a TypeError."""
    memory = ExperienceReplay(max_len=100)
    with pytest.raises(TypeError):
        memory.append((1, 2, 3, 4, 5))
Example #9
policy = EpsGreedy(q_function, ExponentialDecay(EPS_START, EPS_END, EPS_DECAY))

optimizer = torch.optim.Adam(q_function.parameters(),
                             lr=LEARNING_RATE,
                             weight_decay=WEIGHT_DECAY)
criterion = torch.nn.MSELoss

if MEMORY == "PER":
    memory = PrioritizedExperienceReplay(max_len=MEMORY_MAX_SIZE,
                                         beta=LinearGrowth(0.8, 1.0, 0.001))
elif MEMORY == "EXP3":
    memory = EXP3ExperienceReplay(max_len=MEMORY_MAX_SIZE,
                                  alpha=0.001,
                                  beta=0.1)
elif MEMORY == "ER":
    memory = ExperienceReplay(max_len=MEMORY_MAX_SIZE, num_steps=0)
else:
    raise NotImplementedError(f"{MEMORY} not implemented.")

agent = DDQNAgent(
    critic=q_function,
    policy=policy,
    criterion=criterion,
    optimizer=optimizer,
    memory=memory,
    num_iter=1,
    train_frequency=1,
    batch_size=BATCH_SIZE,
    target_update_frequency=TARGET_UPDATE_FREQUENCY,
    gamma=GAMMA,
    clip_gradient_val=1.0,
)
Example #10
    )
    policy = Policy(q_function, ExponentialDecay(start=1.0,
                                                 end=0.01,
                                                 decay=500))
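    # A second Q-network with the same architecture; not used further in this snippet.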
    q_target = NNQFunction(
        dim_state=environment.dim_observation,
        dim_action=environment.dim_action,
        num_states=environment.num_states,
        num_actions=environment.num_actions,
        layers=LAYERS,
        tau=TARGET_UPDATE_TAU,
    )

    optimizer = torch.optim.Adam(q_function.parameters(), lr=LEARNING_RATE)
    criterion = torch.nn.MSELoss
    memory = ExperienceReplay(max_len=MEMORY_MAX_SIZE)

    agent = DDQNAgent(
        critic=q_function,
        policy=policy,
        criterion=criterion,
        optimizer=optimizer,
        memory=memory,
        batch_size=BATCH_SIZE,
        target_update_frequency=TARGET_UPDATE_FREQUENCY,
        gamma=GAMMA,
    )
    rollout_agent(environment,
                  agent,
                  num_episodes=NUM_EPISODES,
                  max_steps=MAX_STEPS)