Example #1
0
def test_compare():
    """Smoke test: fit two identical DQN agents on cartpole data and compare them."""
    dataset, _ = get_cartpole()
    episodes = dataset.episodes
    # first 10 episodes for training, last 10 held out for comparison
    train_episodes, test_episodes = episodes[:10], episodes[-10:]

    candidate = DQN(n_epochs=1)
    candidate.fit(train_episodes, logdir='test_data')

    baseline = DQN(n_epochs=1)
    baseline.fit(train_episodes, logdir='test_data')

    # final positional flag presumably toggles verbose/plot output — confirm against _compare
    score = _compare(candidate, baseline, test_episodes, True)
Example #2
0
def test_evaluate():
    """Smoke test: fit a DQN and check _evaluate reports every expected metric."""
    dataset, _ = get_cartpole()
    episodes = dataset.episodes
    # first 10 episodes for training, last 10 held out for evaluation
    train_episodes, test_episodes = episodes[:10], episodes[-10:]

    agent = DQN(n_epochs=1)
    agent.fit(train_episodes, logdir='test_data')

    scores = _evaluate(agent, test_episodes, True)

    # every evaluation metric key must be present in the returned mapping
    expected_keys = (
        'td_error', 'advantage', 'average_value', 'value_std', 'action_match'
    )
    for metric in expected_keys:
        assert metric in scores
Example #3
0
def main(args):
    """Train a DQN on an Atari offline dataset and track evaluation scorers.

    Args:
        args: parsed CLI namespace; reads .dataset, .seed, .q_func_type, .gpu.
    """
    dataset, env = get_atari(args.dataset)

    # fix RNG seeds for reproducibility
    d3rlpy.seed(args.seed)

    # hold out 20% of episodes for evaluation
    train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

    agent = DQN(
        n_frames=4,  # frame stacking for Atari observations
        q_func_type=args.q_func_type,
        scaler='pixel',
        use_gpu=args.gpu,
    )

    agent.fit(
        train_episodes,
        eval_episodes=test_episodes,
        n_epochs=100,
        scorers={
            'environment': evaluate_on_environment(env, epsilon=0.05),
            'td_error': td_error_scorer,
            'discounted_advantage': discounted_sum_of_advantage_scorer,
            'value_scale': average_value_estimation_scorer,
        },
    )
Example #4
0
def main(args):
    """Train a DQN on an Atari offline dataset and track evaluation scorers.

    Args:
        args: parsed CLI namespace; reads .dataset, .seed, .q_func_type, .gpu.
    """
    dataset, env = get_atari(args.dataset)

    # fix RNG seeds for reproducibility
    d3rlpy.seed(args.seed)

    # hold out 20% of episodes for evaluation
    train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

    # wrap the GPU id only when one was requested on the command line
    if args.gpu is None:
        device = None
    else:
        device = Device(args.gpu)

    agent = DQN(
        n_epochs=100,
        q_func_type=args.q_func_type,
        scaler='pixel',
        use_batch_norm=False,
        use_gpu=device,
    )

    agent.fit(
        train_episodes,
        eval_episodes=test_episodes,
        scorers={
            'environment': evaluate_on_environment(env, epsilon=0.05),
            'td_error': td_error_scorer,
            'discounted_advantage': discounted_sum_of_advantage_scorer,
            'value_scale': average_value_estimation_scorer,
        },
    )
Example #5
0
File: train_dqn.py  Project: kintatta/d3rl
from d3rlpy.algos import DQN
from d3rlpy.datasets import get_cartpole
from d3rlpy.metrics.scorer import evaluate_on_environment

# load the cartpole dataset along with its gym environment
dataset, env = get_cartpole()

# configure the algorithm (single training epoch)
dqn = DQN(n_epochs=1)

# fit on the full set of recorded episodes
dqn.fit(dataset.episodes)

# roll the trained policy out in the live environment with rendering on
scorer = evaluate_on_environment(env, render=True)
scorer(dqn)
Example #6
0
File: train_dqn.py  Project: wx-b/d3rlpy
from d3rlpy.algos import DQN
from d3rlpy.datasets import get_cartpole
from d3rlpy.metrics.scorer import evaluate_on_environment

# load the cartpole dataset along with its gym environment
dataset, env = get_cartpole()

# configure the algorithm with default hyperparameters
dqn = DQN()

# fit on the full set of recorded episodes for a single epoch
dqn.fit(dataset.episodes, n_epochs=1)

# roll the trained policy out in the live environment with rendering on
scorer = evaluate_on_environment(env, render=True)
scorer(dqn)