def nstepdqn_pixel_atari(game, tag=""):
    """Build an n-step DQN configuration for a pixel Atari game and run it.

    Args:
        game: Game identifier handed to every ``Task`` created here.
        tag: Optional prefix prepended to the run tag given to ``run_steps``.
    """
    cfg = NStepDQNConfig()
    cfg.num_workers = 16

    def make_train_task():
        # One Task holding ``cfg.num_workers`` environments; single_process=False
        # presumably spreads them over worker processes (confirm in Task).
        return Task(game, num_envs=cfg.num_workers, single_process=False)

    cfg.task_fn = make_train_task
    cfg.eval_env = Task(game)
    cfg.state_normalizer = ImageNormalizer()
    cfg.reward_normalizer = SignNormalizer()

    def make_optimizer(params):
        return RMSprop(params, lr=1e-4, alpha=0.99, eps=1e-5)

    cfg.optimizer_fn = make_optimizer

    def make_network():
        # Evaluated only when network_fn() is called, so cfg.action_dim is read
        # at that time rather than now.
        # Alternative head: DuelingNet(cfg.action_dim, NatureConvBody()).
        return VanillaNet(cfg.action_dim, NatureConvBody())

    cfg.network_fn = make_network
    cfg.random_action_prob = LinearSchedule(1.0, 0.05, 1e6)
    cfg.discount = 0.99
    cfg.target_network_update_freq = 10000
    # Double-Q targets enabled; flip to False for plain DQN targets.
    cfg.double_q = True
    cfg.rollout_length = 5
    cfg.gradient_clip = 5
    cfg.max_steps = int(2e7)
    # eval_interval / eval_episodes are left at the config defaults
    # (an earlier sketch used eval_interval=int(1e4), eval_episodes=10).
    agent = NStepDQNAgent(cfg)
    agent.run_steps(tag=f'{tag}{nstepdqn_pixel_atari.__name__}-{game}')
def categorical_dqn_pixel_atari(game, tag=""):
    """Configure and run a categorical DQN agent on a pixel Atari game.

    The categorical head is built with ``categorical_n_atoms`` atoms and the
    support bounds ``categorical_v_min`` / ``categorical_v_max`` set below.

    Args:
        game: Game identifier forwarded to every ``Task`` built here.
        tag: Optional prefix for the run tag passed to ``run_steps``.
    """
    config = CategoricalDQNConfig()
    config.history_length = 4
    config.task_fn = lambda: Task(game)
    config.eval_env = Task(game)
    # Adam eps of 0.01 divided by the batch size (32, set below).
    config.optimizer_fn = lambda params: Adam(
        params, lr=0.00025, eps=0.01 / 32)
    # The lambda body runs only when network_fn() is called, so
    # categorical_n_atoms (assigned further down) is set by then.
    config.network_fn = lambda: CategoricalNet(
        config.action_dim, config.categorical_n_atoms, NatureConvBody())
    config.batch_size = 32
    # Pass config.batch_size explicitly so the buffer's batch size follows
    # the value configured above rather than ReplayBuffer's own default.
    config.replay_fn = lambda: ReplayBuffer(
        config.eval_env, memory_size=int(1e6), batch_size=config.batch_size,
        stack=config.history_length)
    # Random-action probability schedule; presumably (start, end, steps) —
    # confirm against LinearSchedule.
    config.random_action_prob = LinearSchedule(1.0, 0.01, 1e6)
    config.state_normalizer = ImageNormalizer()
    config.reward_normalizer = SignNormalizer()
    config.discount = 0.99
    config.target_network_update_freq = 10000
    config.exploration_steps = 50000
    config.categorical_v_max = 10
    config.categorical_v_min = -10
    config.categorical_n_atoms = 51
    config.rollout_length = 4
    config.gradient_clip = 0.5
    # Integer step budget, consistent with the other *_pixel_atari configs.
    config.max_steps = int(2e7)
    CategoricalDQNAgent(config).run_steps(
        tag=f'{tag}{categorical_dqn_pixel_atari.__name__}-{game}')
def dqn_pixel_atari(game, tag=""):
    """Configure and run a (dueling, double-Q) DQN agent on a pixel Atari game.

    Args:
        game: Game identifier forwarded to every ``Task`` built here.
        tag: Optional prefix for the run tag passed to ``run_steps``.
    """
    config = DQNConfig()
    config.history_length = 4
    config.task_fn = lambda: Task(game)
    config.eval_env = Task(game)
    config.optimizer_fn = lambda params: RMSprop(
        params, lr=0.00025, alpha=0.95, eps=0.01, centered=True)
    # Dueling head in use; swap for
    # VanillaNet(config.action_dim, NatureConvBody()) to get a plain Q head.
    config.network_fn = lambda: DuelingNet(config.action_dim, NatureConvBody())
    config.batch_size = 32
    # Pass config.batch_size explicitly so the buffer's batch size follows
    # the value configured above rather than ReplayBuffer's own default.
    config.replay_fn = lambda: ReplayBuffer(
        config.eval_env, memory_size=int(1e6), batch_size=config.batch_size,
        stack=config.history_length)
    config.state_normalizer = ImageNormalizer()
    config.reward_normalizer = SignNormalizer()
    # Random-action probability schedule; presumably (start, end, steps) —
    # confirm against LinearSchedule.
    config.random_action_prob = LinearSchedule(1.0, 0.01, 1e6)
    config.discount = 0.99
    config.target_network_update_freq = 10000
    config.exploration_steps = 10000
    # Double-Q targets enabled; set False for plain DQN targets.
    config.double_q = True
    config.rollout_length = 4
    config.gradient_clip = 5
    # Integer step budget, consistent with the other *_pixel_atari configs.
    config.max_steps = int(2e7)
    DQNAgent(config).run_steps(tag=f'{tag}{dqn_pixel_atari.__name__}-{game}')
def ppo_pixel_atari(game, tag=""):
    """Set up PPO hyper-parameters for a pixel Atari game and launch training.

    Args:
        game: Game identifier handed to the training and evaluation ``Task``s.
        tag: Optional prefix prepended to the run tag given to ``run_steps``.
    """
    cfg = PPOConfig()
    cfg.history_length = 4
    cfg.num_workers = 16

    def make_train_task():
        return Task(game,
                    num_envs=cfg.num_workers,
                    single_process=False,
                    history_length=cfg.history_length)

    cfg.task_fn = make_train_task
    # Evaluation env is built eagerly, with episode_life=False.
    cfg.eval_env = Task(game,
                        episode_life=False,
                        history_length=cfg.history_length)

    def make_optimizer(params):
        return RMSprop(params, lr=0.00025, alpha=0.99, eps=1e-5)

    cfg.optimizer_fn = make_optimizer

    def make_network():
        # Dimensions are read when network_fn() is called, not at setup time.
        state_dim, action_dim = cfg.state_dim, cfg.action_dim
        body = NatureConvBody(in_channels=cfg.history_length)
        return CategoricalActorCriticNet(state_dim, action_dim, body)

    cfg.network_fn = make_network
    cfg.state_normalizer = ImageNormalizer()
    cfg.reward_normalizer = SignNormalizer()
    cfg.discount = 0.99
    cfg.use_gae = True
    cfg.gae_tau = 0.95
    cfg.entropy_weight = 0.01
    cfg.rollout_length = 128
    cfg.gradient_clip = 0.5
    cfg.optimization_epochs = 3
    cfg.mini_batch_size = 32 * 8
    cfg.ppo_ratio_clip = 0.1
    cfg.log_interval = 128 * 8
    cfg.max_steps = int(2e7)
    agent = PPOAgent(cfg)
    agent.run_steps(tag=f'{tag}{ppo_pixel_atari.__name__}-{game}')
def a2c_pixel_atari(game, tag=""):
    """Set up A2C hyper-parameters for a pixel Atari game and launch training.

    Args:
        game: Game identifier handed to the training and evaluation ``Task``s.
        tag: Optional prefix prepended to the run tag given to ``run_steps``.
    """
    cfg = A2CConfig()
    cfg.num_workers = 16

    def make_train_task():
        return Task(game, num_envs=cfg.num_workers, single_process=False)

    cfg.task_fn = make_train_task
    cfg.eval_env = Task(game, episode_life=False)

    def make_optimizer(params):
        return RMSprop(params, lr=1e-4, alpha=0.99, eps=1e-5)

    cfg.optimizer_fn = make_optimizer

    def make_network():
        # state_dim / action_dim are read lazily, when network_fn() is called.
        return CategoricalActorCriticNet(
            cfg.state_dim, cfg.action_dim, NatureConvBody())

    cfg.network_fn = make_network
    cfg.state_normalizer = ImageNormalizer()
    cfg.reward_normalizer = SignNormalizer()
    cfg.discount = 0.99
    cfg.use_gae = True
    # GAE with tau = 1.0 keeps the full (undiscounted-by-tau) advantage sum.
    cfg.gae_tau = 1.0
    cfg.entropy_weight = 0.01
    cfg.rollout_length = 5
    cfg.gradient_clip = 5
    cfg.max_steps = int(2e7)
    agent = A2CAgent(cfg)
    agent.run_steps(tag=f'{tag}{a2c_pixel_atari.__name__}-{game}')
def option_critic_pixel_atari(game, tag=""):
    """Set up Option-Critic hyper-parameters for a pixel Atari game and train.

    Args:
        game: Game identifier handed to the training and evaluation ``Task``s.
        tag: Optional prefix prepended to the run tag given to ``run_steps``.
    """
    cfg = OptionCriticConfig()
    cfg.history_length = 4
    cfg.num_workers = 16

    def make_train_task():
        return Task(game,
                    num_envs=cfg.num_workers,
                    single_process=False,
                    history_length=cfg.history_length)

    cfg.task_fn = make_train_task
    cfg.eval_env = Task(game,
                        episode_life=False,
                        history_length=cfg.history_length)

    def make_optimizer(params):
        return RMSprop(params, lr=1e-4, alpha=0.99, eps=1e-5)

    cfg.optimizer_fn = make_optimizer

    def make_network():
        # The conv body is built first, then action_dim is read (same order
        # as passing them positionally); the network uses 4 options.
        body = NatureConvBody(in_channels=cfg.history_length)
        return OptionCriticNet(body, cfg.action_dim, num_options=4)

    cfg.network_fn = make_network
    # Single-argument LinearSchedule — presumably a constant 0.1; confirm.
    cfg.random_option_prob = LinearSchedule(0.1)
    cfg.state_normalizer = ImageNormalizer()
    cfg.reward_normalizer = SignNormalizer()
    cfg.discount = 0.99
    cfg.target_network_update_freq = 10000
    cfg.rollout_length = 5
    cfg.termination_regularizer = 0.01
    cfg.entropy_weight = 0.01
    cfg.gradient_clip = 5
    cfg.max_steps = int(2e7)
    agent = OptionCriticAgent(cfg)
    agent.run_steps(tag=f'{tag}{option_critic_pixel_atari.__name__}-{game}')
h = frame * scale * h_max screen = display.set_mode((w, h)) env = make_atari('BreakoutNoFrameskip-v4') env = wrap_deepmind(env) replay = ReplayBuffer(env, memory_size=int(1e5), batch_size=batch_size, stack=k) seed = 0 env.seed(seed) np.random.seed(seed) normalizer = ImageNormalizer() state = normalizer(env.reset()) while replay.size < replay.memory_size: action = env.action_space.sample() next_state, reward, done, info = env.step(action) if done: next_state = env.reset() next_state = normalizer(next_state) replay.store([state, action, reward, next_state, done]) state = next_state print(replay.size) while running: # clock.tick(FPS) for event in pygame.event.get(): if event.type == pygame.QUIT: running = False