def categorical_dqn_pixel_atari(game, tag=""):
    # Categorical DQN (C51, 51 atoms) on a pixel Atari game with the Nature conv body.
    config = CategoricalDQNConfig()
    config.history_length = 4
    config.task_fn = lambda: Task(game)
    config.eval_env = Task(game)
    config.optimizer_fn = lambda params: Adam(params, lr=0.00025, eps=0.01 / 32)
    config.network_fn = lambda: CategoricalNet(
        config.action_dim, config.categorical_n_atoms, NatureConvBody())
    config.batch_size = 32
    config.replay_fn = lambda: ReplayBuffer(
        config.eval_env, memory_size=int(1e6), stack=config.history_length)
    config.random_action_prob = LinearSchedule(1.0, 0.01, 1e6)
    config.state_normalizer = ImageNormalizer()
    config.reward_normalizer = SignNormalizer()
    config.discount = 0.99
    config.target_network_update_freq = 10000
    config.exploration_steps = 50000
    config.categorical_v_max = 10
    config.categorical_v_min = -10
    config.categorical_n_atoms = 51
    config.rollout_length = 4
    config.gradient_clip = 0.5
    config.max_steps = 2e7
    CategoricalDQNAgent(config).run_steps(
        tag=f'{tag}{categorical_dqn_pixel_atari.__name__}-{game}')

def categorical_dqn_cart_pole():
    # Categorical DQN on CartPole-v0 with a fully connected body.
    game = 'CartPole-v0'
    config = CategoricalDQNConfig()
    config.task_fn = lambda: Task(game)
    config.eval_env = Task(game)
    config.optimizer_fn = lambda params: RMSprop(params, 0.001)
    config.network_fn = lambda: CategoricalNet(
        config.action_dim, config.categorical_n_atoms, FCBody(config.state_dim))
    config.batch_size = 10
    config.replay_fn = lambda: ReplayBuffer(config.eval_env, memory_size=int(1e4))
    config.random_action_prob = LinearSchedule(1.0, 0.1, 1e4)
    config.discount = 0.99
    config.target_network_update_freq = 200
    config.exploration_steps = 100
    config.categorical_v_max = 100
    config.categorical_v_min = -100
    config.categorical_n_atoms = 50
    config.rollout_length = 4
    config.gradient_clip = 5
    config.max_steps = 1e6
    CategoricalDQNAgent(config).run_steps(
        tag=f'{categorical_dqn_cart_pole.__name__}-{game}')

def dqn_pixel_atari(game, tag=""):
    # (Double / Dueling) DQN on a pixel Atari game with the Nature conv body.
    config = DQNConfig()
    config.history_length = 4
    config.task_fn = lambda: Task(game)
    config.eval_env = Task(game)
    config.optimizer_fn = lambda params: RMSprop(
        params, lr=0.00025, alpha=0.95, eps=0.01, centered=True)
    # config.network_fn = lambda: VanillaNet(config.action_dim, NatureConvBody())
    config.network_fn = lambda: DuelingNet(config.action_dim, NatureConvBody())
    config.batch_size = 32
    config.replay_fn = lambda: ReplayBuffer(
        config.eval_env, memory_size=int(1e6), stack=config.history_length)
    config.state_normalizer = ImageNormalizer()
    config.reward_normalizer = SignNormalizer()
    config.random_action_prob = LinearSchedule(1.0, 0.01, 1e6)
    config.discount = 0.99
    config.target_network_update_freq = 10000
    config.exploration_steps = 10000
    config.double_q = True
    # config.double_q = False
    config.rollout_length = 4
    config.gradient_clip = 5
    config.max_steps = 2e7
    DQNAgent(config).run_steps(tag=f'{tag}{dqn_pixel_atari.__name__}-{game}')

def ddpg_continuous(game, tag=""):
    # DDPG on a continuous-control task with an Ornstein-Uhlenbeck exploration process.
    config = DDPGConfig()
    config.task_fn = lambda: Task(game)
    config.eval_env = Task(game)
    config.network_fn = lambda: DeterministicActorCriticNet(
        config.state_dim, config.action_dim,
        actor_body=FCBody(config.state_dim, (400, 300), gate=F.relu),
        critic_body=TwoLayerFCBodyWithAction(
            config.state_dim, config.action_dim, (400, 300), gate=F.relu),
        actor_opt_fn=lambda params: Adam(params, lr=1e-4),
        critic_opt_fn=lambda params: Adam(params, lr=1e-3))
    config.batch_size = 64
    config.replay_fn = lambda: ReplayBuffer(config.eval_env, memory_size=int(1e6))
    config.random_process_fn = lambda: OrnsteinUhlenbeckProcess(
        size=(config.action_dim,), std=LinearSchedule(0.2))
    config.discount = 0.99
    config.min_memory_size = 64
    config.target_network_mix = 1e-3
    config.max_steps = int(1e6)
    DDPGAgent(config).run_steps(tag=f'{tag}{ddpg_continuous.__name__}-{game}')

def dqn_cart_pole():
    # (Double) DQN on CartPole-v0 with a fully connected body.
    game = 'CartPole-v0'
    config = DQNConfig()
    config.task_fn = lambda: Task(game)
    config.eval_env = Task(game)
    config.optimizer_fn = lambda params: RMSprop(params, 0.001)
    config.network_fn = lambda: VanillaNet(config.action_dim, FCBody(config.state_dim))
    # config.network_fn = lambda: DuelingNet(config.action_dim, FCBody(config.state_dim))
    config.batch_size = 10
    config.replay_fn = lambda: ReplayBuffer(config.eval_env, memory_size=int(1e4))
    config.random_action_prob = LinearSchedule(1.0, 0.1, 1e4)
    config.discount = 0.99
    config.target_network_update_freq = 200
    config.exploration_steps = 1000
    config.double_q = True
    config.rollout_length = 4
    config.gradient_clip = 5
    config.eval_interval = int(5e3)
    config.max_steps = 1e6
    DQNAgent(config).run_steps(tag=f'{dqn_cart_pole.__name__}-{game}')

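# A minimal usage sketch, not part of the original file: each function above is a
# self-contained experiment entry point, so a __main__ block would typically pick
# one of them per run. The commented-out game ids below are illustrative only.
if __name__ == '__main__':
    dqn_cart_pole()
    # categorical_dqn_cart_pole()
    # dqn_pixel_atari('BreakoutNoFrameskip-v4')
    # categorical_dqn_pixel_atari('BreakoutNoFrameskip-v4')
    # ddpg_continuous('HalfCheetah-v2')
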
# Fill a replay buffer with random-policy transitions from Breakout.
# The pygame window is sized to tile a sampled batch, presumably for visualization:
# each sample is shown as its k stacked 84x84 frames laid side by side,
# with h_max samples per column and batch_size // h_max columns.
running = True
frame = 84
k = 4
batch_size = 32
scale = 1
h_max = 8
w = frame * k * scale * (batch_size // h_max)
h = frame * scale * h_max
screen = display.set_mode((w, h))

env = make_atari('BreakoutNoFrameskip-v4')
env = wrap_deepmind(env)
replay = ReplayBuffer(env, memory_size=int(1e5), batch_size=batch_size, stack=k)

seed = 0
env.seed(seed)
np.random.seed(seed)

normalizer = ImageNormalizer()
state = normalizer(env.reset())
# Act uniformly at random until the buffer is full, resetting the env on episode end.
while replay.size < replay.memory_size:
    action = env.action_space.sample()
    next_state, reward, done, info = env.step(action)
    if done:
        next_state = env.reset()
    next_state = normalizer(next_state)
    replay.store([state, action, reward, next_state, done])
    state = next_state
print(replay.size)