Example #1
def categorical_dqn_pixel_atari(game, tag=""):
    config = CategoricalDQNConfig()
    config.history_length = 4

    config.task_fn = lambda: Task(game)
    config.eval_env = Task(game)

    config.optimizer_fn = lambda params: Adam(
        params, lr=0.00025, eps=0.01 / 32)
    config.network_fn = lambda: CategoricalNet(
        config.action_dim, config.categorical_n_atoms, NatureConvBody())

    config.batch_size = 32
    config.replay_fn = lambda: ReplayBuffer(
        config.eval_env, memory_size=int(1e6), stack=config.history_length)
    config.random_action_prob = LinearSchedule(1.0, 0.01, 1e6)

    config.state_normalizer = ImageNormalizer()
    config.reward_normalizer = SignNormalizer()

    config.discount = 0.99
    config.target_network_update_freq = 10000
    config.exploration_steps = 50000
    config.categorical_v_max = 10
    config.categorical_v_min = -10
    config.categorical_n_atoms = 51
    config.rollout_length = 4
    config.gradient_clip = 0.5
    config.max_steps = 2e7
    CategoricalDQNAgent(config).run_steps(
        tag=f'{tag}{categorical_dqn_pixel_atari.__name__}-{game}')
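The categorical_v_min / categorical_v_max / categorical_n_atoms fields above define the fixed support of the C51 return distribution. A minimal sketch of how such a support is usually built and used (numpy only; names are illustrative, not the library's internals):

import numpy as np

v_min, v_max, n_atoms = -10, 10, 51            # values from the config above
atoms = np.linspace(v_min, v_max, n_atoms)     # fixed support z_1 ... z_N of the return distribution
delta_z = (v_max - v_min) / (n_atoms - 1)      # spacing between adjacent atoms
# The network predicts a probability p_i per atom for each action; the Q-value
# used for greedy action selection is the expectation over that support:
# Q(s, a) = sum_i p_i(s, a) * atoms[i]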
Example #2
def categorical_dqn_cart_pole():
    game = 'CartPole-v0'
    config = CategoricalDQNConfig()
    config.task_fn = lambda: Task(game)
    config.eval_env = Task(game)

    config.optimizer_fn = lambda params: RMSprop(params, 0.001)
    config.network_fn = lambda: CategoricalNet(
        config.action_dim, config.categorical_n_atoms, FCBody(config.state_dim))

    config.batch_size = 10
    config.replay_fn = lambda: ReplayBuffer(config.eval_env,
                                            memory_size=int(1e4))
    config.random_action_prob = LinearSchedule(1.0, 0.1, 1e4)

    config.discount = 0.99
    config.target_network_update_freq = 200
    config.exploration_steps = 100
    config.categorical_v_max = 100
    config.categorical_v_min = -100
    config.categorical_n_atoms = 50
    config.rollout_length = 4
    config.gradient_clip = 5

    config.max_steps = 1e6
    CategoricalDQNAgent(config).run_steps(
        tag=f'{categorical_dqn_cart_pole.__name__}-{game}')
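config.random_action_prob = LinearSchedule(1.0, 0.1, 1e4) anneals the epsilon-greedy exploration rate from 1.0 to 0.1 over the first 1e4 steps and then holds it. A minimal stand-alone schedule with that behaviour (a sketch, not the library class):

class LinearScheduleSketch:
    def __init__(self, start, end=None, steps=None):
        if end is None:                        # constant schedule, e.g. LinearSchedule(0.2)
            end, steps = start, 1
        self.start, self.end, self.steps = start, end, steps
        self.t = 0

    def __call__(self):
        frac = min(self.t / self.steps, 1.0)   # goes 0 -> 1 over `steps` calls, then clamps
        self.t += 1
        return self.start + frac * (self.end - self.start)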
Example #3
def dqn_pixel_atari(game, tag=""):
    config = DQNConfig()
    config.history_length = 4
    
    config.task_fn = lambda: Task(game)
    config.eval_env = Task(game)
    
    config.optimizer_fn = lambda params: RMSprop(params, lr=0.00025, alpha=0.95, eps=0.01, centered=True)
    # config.network_fn = lambda: VanillaNet(config.action_dim, NatureConvBody())
    config.network_fn = lambda: DuelingNet(config.action_dim, NatureConvBody())
    
    config.batch_size = 32
    config.replay_fn = lambda: ReplayBuffer(config.eval_env, memory_size=int(1e6), stack=config.history_length)
    
    config.state_normalizer = ImageNormalizer()
    config.reward_normalizer = SignNormalizer()
    
    config.random_action_prob = LinearSchedule(1.0, 0.01, 1e6)
    config.discount = 0.99
    config.target_network_update_freq = 10000
    config.exploration_steps = 10000
    config.double_q = True
    # config.double_q = False
    config.rollout_length = 4
    config.gradient_clip = 5
    config.max_steps = 2e7
    DQNAgent(config).run_steps(tag=f'{tag}{dqn_pixel_atari.__name__}-{game}')
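config.double_q = True switches the bootstrap target from vanilla DQN to Double DQN: the online network picks the greedy next action and the target network evaluates it. A rough PyTorch sketch of that target (tensor names are illustrative, not the agent's internals):

import torch

def double_q_target(q_next_online, q_next_target, rewards, dones, discount=0.99):
    # q_next_online / q_next_target: [batch, num_actions] Q-values for the next states
    # dones: float tensor, 1.0 for terminal transitions
    best_actions = q_next_online.argmax(dim=1, keepdim=True)     # chosen by the online net
    next_q = q_next_target.gather(1, best_actions).squeeze(1)    # evaluated by the target net
    return rewards + discount * (1 - dones) * next_q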
Example #4
def ddpg_continuous(game, tag=""):
    config = DDPGConfig()
    config.task_fn = lambda: Task(game)
    config.eval_env = Task(game)

    config.network_fn = lambda: DeterministicActorCriticNet(
        config.state_dim,
        config.action_dim,
        actor_body=FCBody(config.state_dim, (400, 300), gate=F.relu),
        critic_body=TwoLayerFCBodyWithAction(
            config.state_dim, config.action_dim, (400, 300), gate=F.relu),
        actor_opt_fn=lambda params: Adam(params, lr=1e-4),
        critic_opt_fn=lambda params: Adam(params, lr=1e-3))

    config.batch_size = 64
    config.replay_fn = lambda: ReplayBuffer(config.eval_env,
                                            memory_size=int(1e6))
    config.random_process_fn = lambda: OrnsteinUhlenbeckProcess(
        size=(config.action_dim, ), std=LinearSchedule(0.2))

    config.discount = 0.99
    config.min_memory_size = 64
    config.target_network_mix = 1e-3
    config.max_steps = int(1e6)
    DDPGAgent(config).run_steps(tag=f'{tag}{ddpg_continuous.__name__}-{game}')
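config.target_network_mix = 1e-3 is the Polyak coefficient used to slowly track the target actor and critic in DDPG. A sketch of the usual soft update (illustrative, not the agent's exact code):

def soft_update(target_net, online_net, mix=1e-3):
    # target <- (1 - mix) * target + mix * online, applied parameter-wise
    for target_param, param in zip(target_net.parameters(), online_net.parameters()):
        target_param.data.copy_((1 - mix) * target_param.data + mix * param.data)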
Example #5
def dqn_cart_pole():
    game = 'CartPole-v0'
    config = DQNConfig()
    config.task_fn = lambda: Task(game)
    config.eval_env = Task(game)
    
    config.optimizer_fn = lambda params: RMSprop(params, 0.001)
    config.network_fn = lambda: VanillaNet(config.action_dim, FCBody(config.state_dim))
    # config.network_fn = lambda: DuelingNet(config.action_dim, FCBody(config.state_dim))
    
    config.batch_size = 10
    config.replay_fn = lambda: ReplayBuffer(config.eval_env, memory_size=int(1e4))
    
    config.random_action_prob = LinearSchedule(1.0, 0.1, 1e4)
    config.discount = 0.99
    config.target_network_update_freq = 200
    config.exploration_steps = 1000
    config.double_q = True
    config.rollout_length = 4
    config.gradient_clip = 5
    config.eval_interval = int(5e3)
    config.max_steps = 1e6
    DQNAgent(config).run_steps(tag=f'{dqn_cart_pole.__name__}-{game}')
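A minimal way to launch the CartPole experiment above, assuming the surrounding script imports the same components the examples rely on:

if __name__ == '__main__':
    dqn_cart_pole()   # trains DQN on CartPole-v0 and logs under the tag built above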
Example #6
# make_atari, wrap_deepmind, ReplayBuffer and ImageNormalizer come from the
# same codebase as the examples above; pygame and numpy are the only extra
# imports this snippet needs.
import numpy as np
import pygame
from pygame import display

pygame.init()

running = True      # flag for the pygame event loop (not shown in this excerpt)
frame = 84          # side length of a preprocessed Atari frame, in pixels
k = 4               # number of frames stacked per state
batch_size = 32
scale = 1
h_max = 8           # number of samples stacked vertically in the window
w = frame * k * scale * (batch_size // h_max)
h = frame * scale * h_max

screen = display.set_mode((w, h))

env = make_atari('BreakoutNoFrameskip-v4')
env = wrap_deepmind(env)

replay = ReplayBuffer(env,
                      memory_size=int(1e5),
                      batch_size=batch_size,
                      stack=k)

seed = 0
env.seed(seed)
np.random.seed(seed)
normalizer = ImageNormalizer()
state = normalizer(env.reset())
# Fill the buffer with transitions from a uniformly random policy.
while replay.size < replay.memory_size:
    action = env.action_space.sample()
    next_state, reward, done, info = env.step(action)
    if done:
        next_state = env.reset()
    next_state = normalizer(next_state)
    replay.store([state, action, reward, next_state, done])
    state = next_state
    print(replay.size)
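ImageNormalizer above (and SignNormalizer in the Atari examples) are simple preprocessing wrappers: the image normalizer typically rescales uint8 pixels to [0, 1], and the sign normalizer clips rewards to {-1, 0, +1}. Minimal stand-ins with that behaviour (a sketch, not the library classes):

import numpy as np

class ImageNormalizerSketch:
    def __call__(self, x):
        return np.asarray(x, dtype=np.float32) / 255.0   # uint8 pixels -> [0, 1]

class SignNormalizerSketch:
    def __call__(self, reward):
        return np.sign(reward)                           # clip rewards to {-1, 0, +1}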