Example #1
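The snippets below all follow the configuration style of the DeepRL library (ShangtongZhang/DeepRL) and are listed without their imports. A minimal header sketch, assuming the config classes, agents, networks, normalizers, Task and schedule helpers are re-exported from the package top level and that the optimizers come from torch.optim (the exact import paths may differ in your fork):

from torch.optim import Adam, RMSprop
from deep_rl import *  # Task, *Config, agents, networks, normalizers, schedules (assumed re-exports)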
# Categorical DQN (C51) on Atari from pixels
def categorical_dqn_pixel_atari(game, tag=""):
    config = CategoricalDQNConfig()
    config.history_length = 4

    config.task_fn = lambda: Task(game)
    config.eval_env = Task(game)

    config.optimizer_fn = lambda params: Adam(
        params, lr=0.00025, eps=0.01 / 32)
    config.network_fn = lambda: CategoricalNet(
        config.action_dim, config.categorical_n_atoms, NatureConvBody())

    config.batch_size = 32
    config.replay_fn = lambda: ReplayBuffer(
        config.eval_env, memory_size=int(1e6), stack=config.history_length)
    config.random_action_prob = LinearSchedule(1.0, 0.01, 1e6)

    config.state_normalizer = ImageNormalizer()
    config.reward_normalizer = SignNormalizer()

    config.discount = 0.99
    config.target_network_update_freq = 10000
    config.exploration_steps = 50000
    config.categorical_v_max = 10
    config.categorical_v_min = -10
    config.categorical_n_atoms = 51
    config.rollout_length = 4
    config.gradient_clip = 0.5
    config.max_steps = int(2e7)
    CategoricalDQNAgent(config).run_steps(
        tag=f'{tag}{categorical_dqn_pixel_atari.__name__}-{game}')
Example #2
# N-step (synchronous, multi-worker) DQN on Atari from pixels
def nstepdqn_pixel_atari(game, tag=""):
    config = NStepDQNConfig()
    config.num_workers = 16
    config.task_fn = lambda: Task(
        game, num_envs=config.num_workers, single_process=False)
    config.eval_env = Task(game)

    config.state_normalizer = ImageNormalizer()
    config.reward_normalizer = SignNormalizer()

    config.optimizer_fn = lambda params: RMSprop(
        params, lr=1e-4, alpha=0.99, eps=1e-5)
    config.network_fn = lambda: VanillaNet(config.action_dim, NatureConvBody())
    # config.network_fn = lambda: DuelingNet(config.action_dim, NatureConvBody())

    config.random_action_prob = LinearSchedule(1.0, 0.05, 1e6)
    config.discount = 0.99

    config.target_network_update_freq = 10000
    config.double_q = True
    # config.double_q = False
    config.rollout_length = 5
    config.gradient_clip = 5
    config.max_steps = int(2e7)

    # config.eval_interval = int(1e4)
    # config.eval_episodes = 10
    NStepDQNAgent(config).run_steps(
        tag=f'{tag}{nstepdqn_pixel_atari.__name__}-{game}')
Example #3
# DQN (Dueling network, Double Q-learning) on Atari from pixels
def dqn_pixel_atari(game, tag=""):
    config = DQNConfig()
    config.history_length = 4

    config.task_fn = lambda: Task(game)
    config.eval_env = Task(game)

    config.optimizer_fn = lambda params: RMSprop(
        params, lr=0.00025, alpha=0.95, eps=0.01, centered=True)
    # config.network_fn = lambda: VanillaNet(config.action_dim, NatureConvBody())
    config.network_fn = lambda: DuelingNet(config.action_dim, NatureConvBody())

    config.batch_size = 32
    config.replay_fn = lambda: ReplayBuffer(
        config.eval_env, memory_size=int(1e6), stack=config.history_length)

    config.state_normalizer = ImageNormalizer()
    config.reward_normalizer = SignNormalizer()

    config.random_action_prob = LinearSchedule(1.0, 0.01, 1e6)
    config.discount = 0.99
    config.target_network_update_freq = 10000
    config.exploration_steps = 10000
    config.double_q = True
    # config.double_q = False
    config.rollout_length = 4
    config.gradient_clip = 5
    config.max_steps = int(2e7)
    DQNAgent(config).run_steps(tag=f'{tag}{dqn_pixel_atari.__name__}-{game}')
Example #4
# PPO (clipped surrogate objective) on Atari from pixels
def ppo_pixel_atari(game, tag=""):
    config = PPOConfig()
    config.history_length = 4
    config.num_workers = 16
    config.task_fn = lambda: Task(game,
                                  num_envs=config.num_workers,
                                  single_process=False,
                                  history_length=config.history_length)
    config.eval_env = Task(game,
                           episode_life=False,
                           history_length=config.history_length)

    config.optimizer_fn = lambda params: RMSprop(
        params, lr=0.00025, alpha=0.99, eps=1e-5)
    config.network_fn = lambda: CategoricalActorCriticNet(
        config.state_dim, config.action_dim,
        NatureConvBody(in_channels=config.history_length))

    config.state_normalizer = ImageNormalizer()
    config.reward_normalizer = SignNormalizer()

    config.discount = 0.99
    config.use_gae = True
    config.gae_tau = 0.95
    config.entropy_weight = 0.01
    config.rollout_length = 128
    config.gradient_clip = 0.5
    config.optimization_epochs = 3
    config.mini_batch_size = 32 * 8
    config.ppo_ratio_clip = 0.1
    config.log_interval = 128 * 8
    config.max_steps = int(2e7)
    PPOAgent(config).run_steps(tag=f'{tag}{ppo_pixel_atari.__name__}-{game}')
Example #5
# A2C (synchronous advantage actor-critic) on Atari from pixels
def a2c_pixel_atari(game, tag=""):
    config = A2CConfig()
    config.num_workers = 16
    config.task_fn = lambda: Task(
        game, num_envs=config.num_workers, single_process=False)
    config.eval_env = Task(game, episode_life=False)

    config.optimizer_fn = lambda params: RMSprop(
        params, lr=1e-4, alpha=0.99, eps=1e-5)
    config.network_fn = lambda: CategoricalActorCriticNet(
        config.state_dim, config.action_dim, NatureConvBody())

    config.state_normalizer = ImageNormalizer()
    config.reward_normalizer = SignNormalizer()

    config.discount = 0.99
    config.use_gae = True
    config.gae_tau = 1.0
    config.entropy_weight = 0.01
    config.rollout_length = 5
    config.gradient_clip = 5
    config.max_steps = int(2e7)
    A2CAgent(config).run_steps(tag=f'{tag}{a2c_pixel_atari.__name__}-{game}')
Example #6
# Option-Critic (4 options) on Atari from pixels
def option_critic_pixel_atari(game, tag=""):
    config = OptionCriticConfig()
    config.history_length = 4
    config.num_workers = 16
    config.task_fn = lambda: Task(
        game, num_envs=config.num_workers, single_process=False,
        history_length=config.history_length)
    config.eval_env = Task(
        game, episode_life=False, history_length=config.history_length)

    config.optimizer_fn = lambda params: RMSprop(
        params, lr=1e-4, alpha=0.99, eps=1e-5)
    config.network_fn = lambda: OptionCriticNet(
        NatureConvBody(in_channels=config.history_length),
        config.action_dim, num_options=4)
    config.random_option_prob = LinearSchedule(0.1)

    config.state_normalizer = ImageNormalizer()
    config.reward_normalizer = SignNormalizer()

    config.discount = 0.99
    config.target_network_update_freq = 10000
    config.rollout_length = 5
    config.termination_regularizer = 0.01
    config.entropy_weight = 0.01
    config.gradient_clip = 5
    config.max_steps = int(2e7)
    OptionCriticAgent(config).run_steps(
        tag=f'{tag}{option_critic_pixel_atari.__name__}-{game}')
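For reference, any of these experiments is typically launched from a small driver script. Below is a minimal sketch in the spirit of DeepRL's examples.py; the helper functions (mkdir, set_one_thread, random_seed, select_device) and the game id are assumptions and may be named differently in your fork:

if __name__ == '__main__':
    mkdir('log')       # assumed DeepRL helper: create the log directory
    set_one_thread()   # assumed DeepRL helper: restrict torch to a single CPU thread
    random_seed()      # assumed DeepRL helper: seed the RNGs
    select_device(-1)  # assumed DeepRL helper: -1 = CPU, 0 = first GPU
    # The game id is illustrative; any Atari NoFrameskip environment should work.
    dqn_pixel_atari('BreakoutNoFrameskip-v4', tag='demo-')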