Code Example #1
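Collects 500 environment steps on Pendulum-v0 with an untrained SAC agent via collect() and checks that the returned replay buffer holds close to that many transitions.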
def test_collect_pendulum_with_sac():
    env = gym.make("Pendulum-v0")

    algo = SAC()

    buffer = algo.collect(env, n_steps=500)

    assert buffer.size() > 490 and buffer.size() < 500
Code Example #2
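Runs one epoch of online SAC training on Pendulum-v0 with fit_online(), evaluating on a separate environment; this version uses the older epoch-based arguments (n_epochs, tensorboard).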
def test_fit_online_pendulum_with_sac():
    env = gym.make('Pendulum-v0')
    eval_env = gym.make('Pendulum-v0')

    algo = SAC()

    buffer = ReplayBuffer(1000, env)

    algo.fit_online(env,
                    buffer,
                    n_epochs=1,
                    eval_env=eval_env,
                    logdir='test_data',
                    tensorboard=False)
Code Example #3
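The same test written against the step-based fit_online() API: the training length is given as n_steps rather than n_epochs.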
def test_fit_online_pendulum_with_sac():
    env = gym.make("Pendulum-v0")
    eval_env = gym.make("Pendulum-v0")

    algo = SAC()

    buffer = ReplayBuffer(1000, env)

    algo.fit_online(
        env,
        buffer,
        n_steps=500,
        eval_env=eval_env,
        logdir="test_data",
    )
Code Example #4
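Trains SAC against five Pendulum-v0 copies stepped in parallel through AsyncBatchEnv, storing transitions in a BatchReplayBuffer and driving the loop with fit_batch_online().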
def test_fit_batch_online_pendulum_with_sac():
    make_env = lambda: gym.make("Pendulum-v0")
    env = AsyncBatchEnv([make_env for _ in range(5)])
    eval_env = gym.make("Pendulum-v0")

    algo = SAC()

    buffer = BatchReplayBuffer(1000, env)

    algo.fit_batch_online(
        env,
        buffer,
        n_epochs=1,
        n_steps_per_epoch=500,
        n_updates_per_epoch=1,
        eval_env=eval_env,
        logdir="test_data",
    )
Code Example #5
File: test_sb3.py  Project: ritou11/d3rlpy
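Wraps a SAC instance in SB3Wrapper to get a Stable-Baselines3-style predict() interface, then checks that greedy and sampled actions have the expected shape and that sampling differs from the greedy output.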
def test_sb3_wrapper(observation_shape, action_size, batch_size):
    algo = SAC()
    algo.create_impl(observation_shape, action_size)

    sb3 = SB3Wrapper(algo)

    observations = np.random.random((batch_size, ) + observation_shape)

    # check greedy action
    actions, state = sb3.predict(observations, deterministic=True)
    assert actions.shape == (batch_size, action_size)
    assert state is None

    # check sampling
    stochastic_actions, state = sb3.predict(observations, deterministic=False)
    assert stochastic_actions.shape == (batch_size, action_size)
    assert state is None
    assert not np.allclose(actions, stochastic_actions)
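The excerpt omits its module-level imports; a plausible header, assuming the wrapper lives in d3rlpy.wrappers.sb3 (an assumption, not shown in the excerpt):

import numpy as np

from d3rlpy.algos import SAC
from d3rlpy.wrappers.sb3 import SB3Wrapper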
Code Example #6
File: train_sac.py  Project: kintatta/d3rl
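An offline training script: loads a PyBullet dataset with get_pybullet(), splits the episodes into train and test sets, and fits SAC while tracking several scorers (including evaluation rollouts on the real environment).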
def main(args):
    dataset, env = get_pybullet(args.dataset)

    d3rlpy.seed(args.seed)

    train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

    device = None if args.gpu is None else Device(args.gpu)

    sac = SAC(n_epochs=100, q_func_type=args.q_func_type, use_gpu=device)

    sac.fit(train_episodes,
            eval_episodes=test_episodes,
            scorers={
                'environment': evaluate_on_environment(env),
                'td_error': td_error_scorer,
                'discounted_advantage': discounted_sum_of_advantage_scorer,
                'value_scale': average_value_estimation_scorer,
                'value_std': value_estimation_std_scorer,
                'action_diff': continuous_action_diff_scorer
            })
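This excerpt also omits its imports. A plausible header, assuming the scorers come from d3rlpy.metrics.scorer, the dataset loader from d3rlpy.datasets, the GPU helper from d3rlpy.gpu, and the episode split from scikit-learn (all of these paths are assumptions, not taken from the excerpt):

import d3rlpy
from d3rlpy.algos import SAC
from d3rlpy.datasets import get_pybullet
from d3rlpy.gpu import Device
from d3rlpy.metrics.scorer import (
    evaluate_on_environment,
    td_error_scorer,
    discounted_sum_of_advantage_scorer,
    average_value_estimation_scorer,
    value_estimation_std_scorer,
    continuous_action_diff_scorer,
)
from sklearn.model_selection import train_test_split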
Code Example #7
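Exercises the timelimit_aware flag of fit_online(): when it is enabled, episodes cut off by the gym TimeLimit wrapper are not recorded as terminal transitions, so the buffer should contain no terminal flags; with the flag disabled it should.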
def test_timelimit_aware(timelimit_aware):
    env = gym.make("Pendulum-v0")

    algo = SAC()

    buffer = ReplayBuffer(1000, env)

    algo.fit_online(
        env,
        buffer,
        n_steps=500,
        logdir="test_data",
        timelimit_aware=timelimit_aware,
    )

    terminal_count = 0
    for i in range(len(buffer)):
        terminal_count += int(buffer.transitions[i].terminal)

    if timelimit_aware:
        assert terminal_count == 0
    else:
        assert terminal_count > 0
Code Example #8
File: test_iterators.py  Project: tandakun/d3rlpy
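Runs online SAC training on Pendulum-v0 through the standalone train() iterator instead of calling fit_online() on the algorithm.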
def test_train_with_sac():
    env = gym.make('Pendulum-v0')
    eval_env = gym.make('Pendulum-v0')

    algo = SAC(n_epochs=1)

    buffer = ReplayBuffer(1000, env)

    train(env,
          algo,
          buffer,
          eval_env=eval_env,
          logdir='test_data',
          tensorboard=False)
Code Example #9
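A complete script for online SAC training on Pendulum-v0 with the train() iterator: the algorithm is configured for 100 epochs, each collecting 1000 environment steps and running 100 gradient updates.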
import gym

from d3rlpy.algos import SAC
from d3rlpy.online.buffers import ReplayBuffer
from d3rlpy.online.iterators import train

env = gym.make('Pendulum-v0')
eval_env = gym.make('Pendulum-v0')

# setup algorithm
sac = SAC(n_epochs=100, batch_size=100, use_gpu=False)

# replay buffer for experience replay
buffer = ReplayBuffer(maxlen=100000, env=env)

# start training
# probabilistic policies do not need explorers
train(env,
      sac,
      buffer,
      eval_env=eval_env,
      n_steps_per_epoch=1000,
      n_updates_per_epoch=100)
Code Example #10
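A complete script for batched online training: ten Pendulum-v0 environments stepped in parallel through AsyncBatchEnv, with evaluation on a single environment after every epoch.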
import gym

from d3rlpy.algos import SAC
from d3rlpy.envs import AsyncBatchEnv
from d3rlpy.online.buffers import BatchReplayBuffer

if __name__ == '__main__':
    env = AsyncBatchEnv([lambda: gym.make('Pendulum-v0') for _ in range(10)])
    eval_env = gym.make('Pendulum-v0')

    # setup algorithm
    sac = SAC(batch_size=100, use_gpu=False)

    # replay buffer for experience replay
    buffer = BatchReplayBuffer(maxlen=100000, env=env)

    # start training
    sac.fit_batch_online(env,
                         buffer,
                         n_epochs=100,
                         eval_interval=1,
                         eval_env=eval_env,
                         n_steps_per_epoch=1000,
                         n_updates_per_epoch=1000)
Code Example #11
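A complete script for online SAC training with fit_online(): 100000 environment steps in epochs of 1000 steps, with updates starting after the first 1000 steps.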
import gym

from d3rlpy.algos import SAC
from d3rlpy.online.buffers import ReplayBuffer

env = gym.make('Pendulum-v0')
eval_env = gym.make('Pendulum-v0')

# setup algorithm
sac = SAC(batch_size=100, use_gpu=False)

# replay buffer for experience replay
buffer = ReplayBuffer(maxlen=100000, env=env)

# start training
# probabilistic policies do not need explorers
sac.fit_online(env,
               buffer,
               n_steps=100000,
               eval_env=eval_env,
               n_steps_per_epoch=1000,
               update_start_step=1000)
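SAC samples actions from a stochastic policy, which is why no explorer is passed above. For a deterministic-policy algorithm such as DDPG an explorer would be needed; below is a minimal sketch, assuming a NormalNoise class in d3rlpy.online.explorers and an explorer argument on fit_online (both are assumptions here):

import gym

from d3rlpy.algos import DDPG
from d3rlpy.online.buffers import ReplayBuffer
from d3rlpy.online.explorers import NormalNoise

env = gym.make('Pendulum-v0')

ddpg = DDPG(batch_size=100, use_gpu=False)

# replay buffer for experience replay
buffer = ReplayBuffer(maxlen=100000, env=env)

# deterministic policy, so add Gaussian action noise for exploration
explorer = NormalNoise(mean=0.0, std=0.1)

ddpg.fit_online(env,
                buffer,
                explorer=explorer,
                n_steps=100000,
                n_steps_per_epoch=1000,
                update_start_step=1000)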