def test_fit_batch_online_atari_with_dqn():
    # DummyAtari stands in for a real Atari environment; ChannelFirst
    # converts observations to the channel-first layout DQN expects
    make_env = lambda: ChannelFirst(DummyAtari())
    env = AsyncBatchEnv([make_env for _ in range(2)])
    eval_env = ChannelFirst(DummyAtari())

    algo = DQN(n_frames=4)

    buffer = BatchReplayBuffer(1000, env)

    explorer = LinearDecayEpsilonGreedy()

    algo.fit_batch_online(
        env,
        buffer,
        explorer,
        n_epochs=1,
        n_steps_per_epoch=500,
        n_updates_per_epoch=1,
        eval_env=eval_env,
        logdir="test_data",
    )

    # n_frames=4 stacks four 84x84 frames into a single observation
    assert algo.impl.observation_shape == (4, 84, 84)

def test_fit_batch_online_d4rl_atari_with_dqn():
    # requires the d4rl-atari package, which registers "breakout-mixed-v0"
    import d4rl_atari

    make_env = lambda: gym.make("breakout-mixed-v0", stack=False)
    env = AsyncBatchEnv([make_env for _ in range(2)])
    eval_env = gym.make("breakout-mixed-v0", stack=False)

    algo = DQN(n_frames=4)

    buffer = BatchReplayBuffer(1000, env)

    explorer = LinearDecayEpsilonGreedy()

    algo.fit_batch_online(
        env,
        buffer,
        explorer,
        n_epochs=1,
        n_steps_per_epoch=500,
        n_updates_per_epoch=1,
        eval_env=eval_env,
        logdir="test_data",
        tensorboard=False,
    )

    assert algo.impl.observation_shape == (4, 84, 84)

def test_async_batch_env_continuous(n_envs, n_steps):
    make_env_fn = lambda: gym.make("Pendulum-v0")
    env = AsyncBatchEnv([make_env_fn for _ in range(n_envs)])

    observation_shape = env.observation_space.shape
    action_size = env.action_space.shape[0]

    observations = env.reset()
    assert observations.shape == (n_envs,) + observation_shape

    for _ in range(n_steps):
        actions = np.random.random((n_envs, action_size))
        observations, rewards, terminals, infos = env.step(actions)
        assert observations.shape == (n_envs,) + observation_shape
        assert rewards.shape == (n_envs,)
        assert terminals.shape == (n_envs,)
        assert len(infos) == n_envs

def test_async_batch_env_discrete(n_envs, n_steps):
    make_env_fn = lambda: gym.make("CartPole-v0")
    env = AsyncBatchEnv([make_env_fn for _ in range(n_envs)])

    observation_shape = env.observation_space.shape
    action_size = env.action_space.n

    observations = env.reset()
    assert observations.shape == (n_envs,) + observation_shape

    for _ in range(n_steps):
        actions = np.random.randint(action_size, size=n_envs)
        observations, rewards, terminals, infos = env.step(actions)
        assert observations.shape == (n_envs,) + observation_shape
        assert rewards.shape == (n_envs,)
        assert terminals.shape == (n_envs,)
        assert len(infos) == n_envs

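# The two tests above take n_envs and n_steps as arguments, which implies
# pytest parametrization that did not survive in this section.  A minimal
# sketch of that wiring; the helper name and the values below are
# illustrative assumptions, not the originals.
import pytest


@pytest.mark.parametrize("n_envs", [2, 5])
@pytest.mark.parametrize("n_steps", [10])
def test_async_batch_env_shapes(n_envs, n_steps):
    # exercise both batch-env tests under the same parametrization
    test_async_batch_env_continuous(n_envs, n_steps)
    test_async_batch_env_discrete(n_envs, n_steps)
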
def test_fit_batch_online_pendulum_with_sac():
    make_env = lambda: gym.make("Pendulum-v0")
    env = AsyncBatchEnv([make_env for _ in range(5)])
    eval_env = gym.make("Pendulum-v0")

    algo = SAC()

    buffer = BatchReplayBuffer(1000, env)

    algo.fit_batch_online(
        env,
        buffer,
        n_epochs=1,
        n_steps_per_epoch=500,
        n_updates_per_epoch=1,
        eval_env=eval_env,
        logdir="test_data",
    )

def test_fit_batch_online_cartpole_with_dqn():
    make_env = lambda: gym.make("CartPole-v0")
    env = AsyncBatchEnv([make_env for _ in range(5)])
    eval_env = gym.make("CartPole-v0")

    algo = DQN()

    buffer = BatchReplayBuffer(1000, env)

    explorer = LinearDecayEpsilonGreedy()

    algo.fit_batch_online(
        env,
        buffer,
        explorer,
        n_epochs=1,
        n_steps_per_epoch=500,
        n_updates_per_epoch=1,
        eval_env=eval_env,
        logdir="test_data",
    )

import gym

from d3rlpy.algos import DQN
from d3rlpy.envs import AsyncBatchEnv
from d3rlpy.online.buffers import BatchReplayBuffer
from d3rlpy.online.explorers import LinearDecayEpsilonGreedy

if __name__ == '__main__':
    # run 10 CartPole environments in parallel
    env = AsyncBatchEnv([lambda: gym.make('CartPole-v0') for _ in range(10)])
    eval_env = gym.make('CartPole-v0')

    # setup algorithm
    dqn = DQN(batch_size=32,
              learning_rate=1e-3,
              target_update_interval=1000,
              use_gpu=False)

    # replay buffer for experience replay
    buffer = BatchReplayBuffer(maxlen=100000, env=env)

    # epsilon-greedy explorer
    explorer = LinearDecayEpsilonGreedy(start_epsilon=1.0,
                                        end_epsilon=0.1,
                                        duration=100000)

    # start training
    dqn.fit_batch_online(env,
                         buffer,
                         explorer,
                         n_epochs=100,
                         eval_interval=1,
                         eval_env=eval_env,
                         n_steps_per_epoch=1000,
                         n_updates_per_epoch=1000)
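    # optionally persist the learned parameters with d3rlpy's save_model();
    # the file name here is illustrative
    dqn.save_model('dqn_cartpole.pt')
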
import gym

from d3rlpy.algos import SAC
from d3rlpy.envs import AsyncBatchEnv
from d3rlpy.online.buffers import BatchReplayBuffer

if __name__ == '__main__':
    # run 10 Pendulum environments in parallel
    env = AsyncBatchEnv([lambda: gym.make('Pendulum-v0') for _ in range(10)])
    eval_env = gym.make('Pendulum-v0')

    # setup algorithm
    sac = SAC(batch_size=100, use_gpu=False)

    # replay buffer for experience replay
    buffer = BatchReplayBuffer(maxlen=100000, env=env)

    # start training
    sac.fit_batch_online(env,
                         buffer,
                         n_epochs=100,
                         eval_interval=1,
                         eval_env=eval_env,
                         n_steps_per_epoch=1000,
                         n_updates_per_epoch=1000)
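    # a minimal evaluation sketch after training, assuming d3rlpy's
    # predict() API, which maps a batch of observations to greedy actions
    import numpy as np

    observation = eval_env.reset()
    done = False
    episode_reward = 0.0
    while not done:
        action = sac.predict(np.expand_dims(observation, axis=0))[0]
        observation, reward, done, _ = eval_env.step(action)
        episode_reward += float(reward)
    print('evaluation reward:', episode_reward)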