Example #1
def test_fit_batch_online_atari_with_dqn():
    import d4rl_atari

    make_env = lambda: ChannelFirst(DummyAtari())
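    # AsyncBatchEnv runs the environments in parallel and exposes a batched reset/step interface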
    env = AsyncBatchEnv([make_env for _ in range(2)])
    eval_env = ChannelFirst(DummyAtari())

    algo = DQN(n_frames=4)

    buffer = BatchReplayBuffer(1000, env)

    explorer = LinearDecayEpsilonGreedy()

    algo.fit_batch_online(
        env,
        buffer,
        explorer,
        n_epochs=1,
        n_steps_per_epoch=500,
        n_updates_per_epoch=1,
        eval_env=eval_env,
        logdir="test_data",
    )

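    # n_frames=4 stacks frames, so observations become (4, 84, 84)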
    assert algo.impl.observation_shape == (4, 84, 84)
Example #2
def test_fit_batch_online_atari_with_dqn():
    import d4rl_atari

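    # stack=False returns single frames; DQN stacks them itself via n_frames=4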
    make_env = lambda: gym.make("breakout-mixed-v0", stack=False)
    env = AsyncBatchEnv([make_env for _ in range(2)])
    eval_env = gym.make("breakout-mixed-v0", stack=False)

    algo = DQN(n_frames=4)

    buffer = BatchReplayBuffer(1000, env)

    explorer = LinearDecayEpsilonGreedy()

    algo.fit_batch_online(
        env,
        buffer,
        explorer,
        n_epochs=1,
        n_steps_per_epoch=500,
        n_updates_per_epoch=1,
        eval_env=eval_env,
        logdir="test_data",
        tensorboard=False,
    )

    assert algo.impl.observation_shape == (4, 84, 84)
Example #3
def test_async_batch_env_continuous(n_envs, n_steps):
    make_env_fn = lambda: gym.make("Pendulum-v0")
    env = AsyncBatchEnv([make_env_fn for _ in range(n_envs)])

    observation_shape = env.observation_space.shape
    action_size = env.action_space.shape[0]

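    # reset returns a batch of initial observations, one row per environment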
    observations = env.reset()
    assert observations.shape == (n_envs, ) + observation_shape

    for _ in range(n_steps):
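        # sample one random continuous action per environment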
        actions = np.random.random((n_envs, action_size))

        observations, rewards, terminals, infos = env.step(actions)

        assert observations.shape == (n_envs, ) + observation_shape
        assert rewards.shape == (n_envs, )
        assert terminals.shape == (n_envs, )
        assert len(infos) == n_envs
Example #4
def test_async_batch_env_discrete(n_envs, n_steps):
    make_env_fn = lambda: gym.make("CartPole-v0")
    env = AsyncBatchEnv([make_env_fn for _ in range(n_envs)])

    observation_shape = env.observation_space.shape
    action_size = env.action_space.n

    observations = env.reset()
    assert observations.shape == (n_envs, ) + observation_shape

    for _ in range(n_steps):
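        # sample one random discrete action per environment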
        actions = np.random.randint(action_size, size=n_envs)

        observations, rewards, terminals, infos = env.step(actions)

        assert observations.shape == (n_envs, ) + observation_shape
        assert rewards.shape == (n_envs, )
        assert terminals.shape == (n_envs, )
        assert len(infos) == n_envs
Example #5
def test_fit_batch_online_pendulum_with_sac():
    make_env = lambda: gym.make("Pendulum-v0")
    env = AsyncBatchEnv([make_env for _ in range(5)])
    eval_env = gym.make("Pendulum-v0")

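    # SAC explores with its stochastic policy, so no explorer is passed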
    algo = SAC()

    buffer = BatchReplayBuffer(1000, env)

    algo.fit_batch_online(
        env,
        buffer,
        n_epochs=1,
        n_steps_per_epoch=500,
        n_updates_per_epoch=1,
        eval_env=eval_env,
        logdir="test_data",
    )
Example #6
def test_fit_batch_online_cartpole_with_dqn():
    make_env = lambda: gym.make("CartPole-v0")
    env = AsyncBatchEnv([make_env for _ in range(5)])
    eval_env = gym.make("CartPole-v0")

    algo = DQN()

    buffer = BatchReplayBuffer(1000, env)

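    # epsilon-greedy exploration for the discrete-action DQN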
    explorer = LinearDecayEpsilonGreedy()

    algo.fit_batch_online(
        env,
        buffer,
        explorer,
        n_epochs=1,
        n_steps_per_epoch=500,
        n_updates_per_epoch=1,
        eval_env=eval_env,
        logdir="test_data",
    )
Example #7
import gym

from d3rlpy.algos import DQN
from d3rlpy.envs import AsyncBatchEnv
from d3rlpy.online.buffers import BatchReplayBuffer
from d3rlpy.online.explorers import LinearDecayEpsilonGreedy

if __name__ == '__main__':
    env = AsyncBatchEnv([lambda: gym.make('CartPole-v0') for _ in range(10)])
    eval_env = gym.make('CartPole-v0')

    # setup algorithm
    dqn = DQN(batch_size=32,
              learning_rate=1e-3,
              target_update_interval=1000,
              use_gpu=False)

    # replay buffer for experience replay
    buffer = BatchReplayBuffer(maxlen=100000, env=env)

    # epsilon-greedy explorer
    explorer = LinearDecayEpsilonGreedy(start_epsilon=1.0,
                                        end_epsilon=0.1,
                                        duration=100000)

    # start training
    dqn.fit_batch_online(env,
                         buffer,
                         explorer,
                         n_epochs=100,
                         eval_interval=1,
                         eval_env=eval_env,
                         n_steps_per_epoch=1000,
                         n_updates_per_epoch=1000)
Example #8
import gym

from d3rlpy.algos import SAC
from d3rlpy.envs import AsyncBatchEnv
from d3rlpy.online.buffers import BatchReplayBuffer

if __name__ == '__main__':
    env = AsyncBatchEnv([lambda: gym.make('Pendulum-v0') for _ in range(10)])
    eval_env = gym.make('Pendulum-v0')

    # setup algorithm
    sac = SAC(batch_size=100, use_gpu=False)

    # replay buffer for experience replay
    buffer = BatchReplayBuffer(maxlen=100000, env=env)

    # start training
    sac.fit_batch_online(env,
                         buffer,
                         n_epochs=100,
                         eval_interval=1,
                         eval_env=eval_env,
                         n_steps_per_epoch=1000,
                         n_updates_per_epoch=1000)