Example #1
def dqn_with_prioritized_experience(env, n_episodes=None):
    # DQN with epsilon-greedy exploration, prioritized experience replay, and fixed-Q targets
    sched_step = 1.0 / n_episodes if n_episodes is not None else 0.001

    model = build_network(env)
    target_model = build_network(env)
    alpha_sched = LinearSchedule(start=0.0, end=1.0, step=sched_step)
    beta_sched = LinearSchedule(start=0.0, end=1.0, step=sched_step)
    experience = PrioritizedExperienceReplay(maxlen=10000,
                                             sample_batch_size=64,
                                             min_size_to_sample=1000,
                                             alpha_sched=alpha_sched,
                                             beta_sched=beta_sched)
    decay_sched = ExponentialSchedule(start=1.0, end=0.01, step=0.995)
    exploration = EpsilonGreedyExploration(decay_sched=decay_sched)
    fixed_target = FixedQTarget(target_model,
                                target_update_step=500,
                                use_soft_targets=True,
                                use_double_q=True)
    agent = DQNAgent(env,
                     model,
                     gamma=0.99,
                     exploration=exploration,
                     experience=experience,
                     fixed_q_target=fixed_target)

    # Pre-load samples into the experience replay buffer.
    # This can also happen implicitly during regular training episodes,
    # but early training may then overfit to the first samples collected.
    experience.bootstrap(env)

    # Perform the training
    return train_dqn(agent, n_episodes)
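
Example #1 anneals both the prioritization exponent (alpha) and the importance-sampling exponent (beta) from 0.0 to 1.0 over training via LinearSchedule. The schedule helpers themselves are not shown in these snippets; the following is a minimal sketch of what a linear schedule could look like, assuming a step() method that advances and returns the current value (the method name and attributes are assumptions, not this codebase's actual API):

class LinearSchedule:
    """Anneal a value linearly from `start` toward `end` by `step` per update (sketch)."""

    def __init__(self, start, end, step):
        self.value = start
        self.end = end
        self.step_size = step
        # +1 for an increasing schedule (e.g. 0.0 -> 1.0), -1 for a decreasing one
        self._sign = 1.0 if end >= start else -1.0

    def step(self):
        """Advance one update and return the current value, clamped at `end`."""
        self.value += self._sign * self.step_size
        if self._sign > 0:
            self.value = min(self.value, self.end)
        else:
            self.value = max(self.value, self.end)
        return self.value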
Example #2
def dqn_with_fixed_targets(env, n_episodes=None):
    # DQN with epsilon-greedy exploration, experience replay, and fixed-Q targets
    model = build_network(env)
    target_model = build_network(env)
    experience = ExperienceReplay(maxlen=2000,
                                  sample_batch_size=32,
                                  min_size_to_sample=100)
    decay_sched = ExponentialSchedule(start=1.0, end=0.01, step=0.99)
    exploration = EpsilonGreedyExploration(decay_sched=decay_sched)
    fixed_target = FixedQTarget(target_model,
                                target_update_step=500,
                                use_soft_targets=True)
    agent = DQNAgent(env,
                     model,
                     gamma=0.99,
                     exploration=exploration,
                     experience=experience,
                     fixed_q_target=fixed_target)

    # Pre-load samples into the experience replay buffer.
    # This can also happen implicitly during regular training episodes,
    # but early training may then overfit to the first samples collected.
    experience.bootstrap(env)

    # Perform the training
    return train_dqn(agent, n_episodes, debug=n_episodes is None)
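
Example #2 keeps a separate target network and, because use_soft_targets=True, blends it toward the online network rather than copying the weights wholesale every target_update_step steps. A hedged sketch of that soft (Polyak) update rule, assuming the weights are plain NumPy arrays and that a small tau controls the blend rate (the function name and tau value are assumptions, not the library's FixedQTarget internals):

def soft_update(target_weights, online_weights, tau=0.005):
    """Polyak averaging: target <- tau * online + (1 - tau) * target (sketch)."""
    return [tau * w_online + (1.0 - tau) * w_target
            for w_online, w_target in zip(online_weights, target_weights)]

If build_network happens to return a Keras model, this could be applied as target_model.set_weights(soft_update(target_model.get_weights(), model.get_weights())), though that pairing is only an assumption here.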
Example #3
def basic_dqn(env, n_episodes):
    # Basic DQN with epsilon-greedy exploration
    model = build_network(env)
    decay_sched = ExponentialSchedule(start=1.0, end=0.01, step=0.99)
    exploration = EpsilonGreedyExploration(decay_sched=decay_sched)
    agent = DQNAgent(env, model, gamma=0.99, exploration=exploration)

    # Perform the training
    return train_dqn(agent, n_episodes, debug=True)
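
Example #3 is the plain variant: no replay buffer and no target network, just epsilon-greedy action selection with epsilon decayed multiplicatively from 1.0 toward 0.01 (multiplied by 0.99 at each decay step). A minimal sketch of the epsilon-greedy rule itself, assuming q_values is a 1-D array of per-action value estimates (this helper is illustrative, not the EpsilonGreedyExploration implementation):

import random

import numpy as np

def epsilon_greedy_action(q_values, epsilon):
    """With probability epsilon take a random action, otherwise the greedy one (sketch)."""
    if random.random() < epsilon:
        return random.randrange(len(q_values))
    return int(np.argmax(q_values))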
Example #4
def dqn_with_experience(env, n_episodes):
    # DQN with epsilon-greedy exploration and experience replay
    model = build_network(env)
    experience = ExperienceReplay(maxlen=10000,
                                  sample_batch_size=64,
                                  min_size_to_sample=1000)
    decay_sched = ExponentialSchedule(start=1.0, end=0.01, step=0.995)
    exploration = EpsilonGreedyExploration(decay_sched=decay_sched)
    agent = DQNAgent(env,
                     model,
                     gamma=0.99,
                     exploration=exploration,
                     experience=experience)

    # Pre-load samples into the experience replay buffer.
    # This can also happen implicitly during regular training episodes,
    # but early training may then overfit to the first samples collected.
    experience.bootstrap(env)

    # Perform the training
    return train_dqn(agent, n_episodes, debug=True)
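
All of the replay-based examples call experience.bootstrap(env) before training so the buffer holds at least min_size_to_sample transitions gathered by a random policy, instead of letting the first gradient updates lean on a handful of correlated early transitions. A rough sketch of what such a bootstrap could do, assuming the classic Gym interface (reset() returns a state, step() returns (next_state, reward, done, info)) and a buffer.add(...) method; both are assumptions about this codebase:

def bootstrap_replay(env, buffer, n_samples=1000):
    """Fill a replay buffer with transitions from a uniformly random policy (sketch)."""
    state = env.reset()
    for _ in range(n_samples):
        action = env.action_space.sample()
        next_state, reward, done, _ = env.step(action)
        buffer.add(state, action, reward, next_state, done)
        # Start a new episode when the current one ends
        state = env.reset() if done else next_state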