import gym
import numpy as np
import psutil
import torch

from garage.envs import GymEnv
from garage.envs.wrappers import (ClipReward, EpisodicLife, FireReset,
                                  Grayscale, MaxAndSkip, Noop, Resize,
                                  StackFrames)
from garage.experiment.deterministic import set_seed
from garage.np.exploration_policies import EpsilonGreedyPolicy
from garage.replay_buffer import PathBuffer
from garage.sampler import FragmentWorker, LocalSampler
from garage.torch import set_gpu_mode
from garage.torch.algos import DQN
from garage.torch.policies import DiscreteQFArgmaxPolicy
from garage.torch.q_functions import (DiscreteCNNQFunction,
                                      DiscreteMLPQFunction)
from garage.trainer import Trainer


def setup():
    """Build a small DQN configuration on CartPole-v0 for tests.

    Returns:
        tuple: (algo, env, replay_buffer, n_epochs, sampler_batch_size).
    """
    set_seed(24)
    n_epochs = 11
    steps_per_epoch = 10
    sampler_batch_size = 512
    # The epsilon decay schedule spans 100 epochs' worth of steps, far more
    # than the 11-epoch run, so epsilon stays well above min_epsilon here.
    num_timesteps = 100 * steps_per_epoch * sampler_batch_size
    env = GymEnv('CartPole-v0')
    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))
    qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(8, 5))
    policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
    exploration_policy = EpsilonGreedyPolicy(env_spec=env.spec,
                                             policy=policy,
                                             total_timesteps=num_timesteps,
                                             max_epsilon=1.0,
                                             min_epsilon=0.01,
                                             decay_ratio=0.4)
    # A sampler is passed to DQN at construction time, matching dqn_cartpole
    # and dqn_atari below; without it, Trainer.train() cannot collect steps.
    sampler = LocalSampler(agents=exploration_policy,
                           envs=env,
                           max_episode_length=env.spec.max_episode_length,
                           worker_class=FragmentWorker)
    algo = DQN(env_spec=env.spec,
               policy=policy,
               qf=qf,
               exploration_policy=exploration_policy,
               replay_buffer=replay_buffer,
               sampler=sampler,
               steps_per_epoch=steps_per_epoch,
               qf_lr=5e-5,
               discount=0.9,
               min_buffer_size=int(1e4),
               n_train_steps=500,
               target_update_freq=30,
               buffer_batch_size=64)
    return algo, env, replay_buffer, n_epochs, sampler_batch_size
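
# Usage sketch for setup(): a hedged illustration, not part of the original
# test code. `snapshot_config` is an assumption here -- in practice it would
# be a garage SnapshotConfig or the ctxt supplied by wrap_experiment.
def run_setup_sketch(snapshot_config):
    algo, env, replay_buffer, n_epochs, sampler_batch_size = setup()
    trainer = Trainer(snapshot_config)
    trainer.setup(algo, env)
    # replay_buffer is filled by the algorithm itself during training.
    trainer.train(n_epochs=n_epochs, batch_size=sampler_batch_size)
    env.close()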

def dqn_cartpole(ctxt=None, seed=24):
    """Train DQN with the CartPole-v0 environment.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by Trainer to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.
    """
    set_seed(seed)
    trainer = Trainer(ctxt)

    n_epochs = 100
    steps_per_epoch = 10
    sampler_batch_size = 512
    num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size
    env = GymEnv('CartPole-v0')
    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))
    qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(8, 5))
    policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
    exploration_policy = EpsilonGreedyPolicy(env_spec=env.spec,
                                             policy=policy,
                                             total_timesteps=num_timesteps,
                                             max_epsilon=1.0,
                                             min_epsilon=0.01,
                                             decay_ratio=0.4)
    sampler = LocalSampler(agents=exploration_policy,
                           envs=env,
                           max_episode_length=env.spec.max_episode_length,
                           worker_class=FragmentWorker)
    algo = DQN(env_spec=env.spec,
               policy=policy,
               qf=qf,
               exploration_policy=exploration_policy,
               replay_buffer=replay_buffer,
               sampler=sampler,
               steps_per_epoch=steps_per_epoch,
               qf_lr=5e-5,
               discount=0.9,
               min_buffer_size=int(1e4),
               n_train_steps=500,
               target_update_freq=30,
               buffer_batch_size=64)
    trainer.setup(algo, env)
    trainer.train(n_epochs=n_epochs, batch_size=sampler_batch_size)
    env.close()
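
# Invocation sketch (assumption): in the garage examples dqn_cartpole is
# decorated with @wrap_experiment and called at module scope. Wrapping it
# explicitly, as below, is equivalent and keeps the definition above intact.
def run_dqn_cartpole():
    from garage import wrap_experiment
    wrap_experiment(dqn_cartpole)(seed=24)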
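
# `dqn_atari` below reads a module-level `hyperparams` dict that is never
# defined in this section. The sketch below restores the keys the function
# expects; every value is illustrative (an assumption, not taken from this
# section) and should be tuned for real runs.
hyperparams = dict(n_epochs=500,
                   steps_per_epoch=20,
                   sampler_batch_size=500,
                   lr=1e-4,
                   discount=0.99,
                   min_buffer_size=int(1e4),
                   n_train_steps=125,
                   target_update_freq=2,
                   buffer_batch_size=32,
                   max_epsilon=1.0,
                   min_epsilon=0.01,
                   decay_ratio=0.1,
                   buffer_size=int(1e4),
                   hidden_sizes=(512, ),
                   hidden_channels=(32, 64, 64),
                   kernel_sizes=(8, 4, 3),
                   strides=(4, 2, 1),
                   clip_gradient=10)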

def dqn_atari(ctxt=None,
              env=None,
              seed=24,
              n_workers=psutil.cpu_count(logical=False),
              max_episode_length=None,
              **kwargs):
    """Train DQN with an Atari environment such as PongNoFrameskip-v4.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by Trainer to create the snapshotter.
        env (str): Name of the Atari environment, e.g.
            'PongNoFrameskip-v4'.
        seed (int): Used to seed the random number generator to produce
            determinism.
        n_workers (int): Number of workers to use. Defaults to the number
            of physical CPU cores available.
        max_episode_length (int): Max length of an episode. If None,
            defaults to the time limit specific to the environment. Used
            by integration tests.
        kwargs (dict): Hyperparameters to be saved to variant.json.
    """
    assert n_workers > 0
    assert env is not None
    # Standard Atari preprocessing pipeline.
    env = gym.make(env)
    env = Noop(env, noop_max=30)
    env = MaxAndSkip(env, skip=4)
    env = EpisodicLife(env)
    if 'FIRE' in env.unwrapped.get_action_meanings():
        env = FireReset(env)
    env = Grayscale(env)
    env = Resize(env, 84, 84)
    env = ClipReward(env)
    env = StackFrames(env, 4, axis=0)
    env = GymEnv(env, max_episode_length=max_episode_length, is_image=True)

    set_seed(seed)
    trainer = Trainer(ctxt)

    # Hyperparameters come from the module-level `hyperparams` dict.
    n_epochs = hyperparams['n_epochs']
    steps_per_epoch = hyperparams['steps_per_epoch']
    sampler_batch_size = hyperparams['sampler_batch_size']
    num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size

    replay_buffer = PathBuffer(
        capacity_in_transitions=hyperparams['buffer_size'])

    qf = DiscreteCNNQFunction(
        env_spec=env.spec,
        image_format='NCHW',
        hidden_channels=hyperparams['hidden_channels'],
        kernel_sizes=hyperparams['kernel_sizes'],
        strides=hyperparams['strides'],
        hidden_w_init=(
            lambda x: torch.nn.init.orthogonal_(x, gain=np.sqrt(2))),
        hidden_sizes=hyperparams['hidden_sizes'])

    policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
    exploration_policy = EpsilonGreedyPolicy(
        env_spec=env.spec,
        policy=policy,
        total_timesteps=num_timesteps,
        max_epsilon=hyperparams['max_epsilon'],
        min_epsilon=hyperparams['min_epsilon'],
        decay_ratio=hyperparams['decay_ratio'])

    sampler = LocalSampler(agents=exploration_policy,
                           envs=env,
                           max_episode_length=env.spec.max_episode_length,
                           worker_class=FragmentWorker,
                           n_workers=n_workers)

    algo = DQN(env_spec=env.spec,
               policy=policy,
               qf=qf,
               exploration_policy=exploration_policy,
               replay_buffer=replay_buffer,
               sampler=sampler,
               steps_per_epoch=steps_per_epoch,
               qf_lr=hyperparams['lr'],
               clip_gradient=hyperparams['clip_gradient'],
               discount=hyperparams['discount'],
               min_buffer_size=hyperparams['min_buffer_size'],
               n_train_steps=hyperparams['n_train_steps'],
               target_update_freq=hyperparams['target_update_freq'],
               buffer_batch_size=hyperparams['buffer_batch_size'])

    # Train on GPU when available; otherwise stay on a single CPU thread.
    set_gpu_mode(False)
    torch.set_num_threads(1)
    if torch.cuda.is_available():
        set_gpu_mode(True)
        algo.to()

    trainer.setup(algo, env)
    trainer.train(n_epochs=n_epochs, batch_size=sampler_batch_size)
    env.close()
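
# Entry-point sketch (assumption): the garage examples expose dqn_atari via a
# command-line front end; the minimal guarded call below is an illustrative
# stand-in, with the environment name chosen as an example.
if __name__ == '__main__':
    from garage import wrap_experiment
    wrap_experiment(dqn_atari)(env='PongNoFrameskip-v4')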