Example #1
from sklearn.model_selection import train_test_split

import d3rlpy
from d3rlpy.datasets import get_atari
from d3rlpy.algos import DiscreteBC
from d3rlpy.metrics.scorer import evaluate_on_environment


def main(args):
    dataset, env = get_atari(args.dataset)

    d3rlpy.seed(args.seed)

    train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

    bc = DiscreteBC(
        n_frames=4,  # frame stacking
        scaler='pixel',
        use_gpu=args.gpu)

    bc.fit(train_episodes,
           eval_episodes=test_episodes,
           n_epochs=100,
           scorers={'environment': evaluate_on_environment(env, epsilon=0.05)})
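
# (not in the original excerpt) main() expects an argparse namespace; a minimal
# sketch of the parsing it assumes -- flag names and defaults are illustrative
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='breakout-mixed-v0')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--gpu', type=int, default=None)
    main(parser.parse_args())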
Example #2
from sklearn.model_selection import train_test_split

import d3rlpy
from d3rlpy.datasets import get_atari
from d3rlpy.algos import DiscreteBC
from d3rlpy.gpu import Device
from d3rlpy.metrics.scorer import evaluate_on_environment


def main(args):
    dataset, env = get_atari(args.dataset)

    d3rlpy.seed(args.seed)

    train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

    device = None if args.gpu is None else Device(args.gpu)

    bc = DiscreteBC(n_epochs=100,
                    scaler='pixel',
                    use_batch_norm=False,
                    use_gpu=device)

    bc.fit(train_episodes,
           eval_episodes=test_episodes,
           scorers={'environment': evaluate_on_environment(env, epsilon=0.05)})
Example #3
from sklearn.model_selection import train_test_split

import d3rlpy
from d3rlpy.datasets import get_atari
from d3rlpy.algos import DQN
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.metrics.scorer import td_error_scorer
from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer
from d3rlpy.metrics.scorer import average_value_estimation_scorer


def main(args):
    dataset, env = get_atari(args.dataset)

    d3rlpy.seed(args.seed)

    train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

    dqn = DQN(
        n_frames=4,  # frame stacking
        q_func_type=args.q_func_type,
        scaler='pixel',
        use_gpu=args.gpu)

    dqn.fit(train_episodes,
            eval_episodes=test_episodes,
            n_epochs=100,
            scorers={
                'environment': evaluate_on_environment(env, epsilon=0.05),
                'td_error': td_error_scorer,
                'discounted_advantage': discounted_sum_of_advantage_scorer,
                'value_scale': average_value_estimation_scorer
            })
Example #4
from sklearn.model_selection import train_test_split

import d3rlpy
from d3rlpy.datasets import get_atari
from d3rlpy.algos import DQN
from d3rlpy.gpu import Device
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.metrics.scorer import td_error_scorer
from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer
from d3rlpy.metrics.scorer import average_value_estimation_scorer


def main(args):
    dataset, env = get_atari(args.dataset)

    d3rlpy.seed(args.seed)

    train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

    device = None if args.gpu is None else Device(args.gpu)

    dqn = DQN(n_epochs=100,
              q_func_type=args.q_func_type,
              scaler='pixel',
              use_batch_norm=False,
              use_gpu=device)

    dqn.fit(train_episodes,
            eval_episodes=test_episodes,
            scorers={
                'environment': evaluate_on_environment(env, epsilon=0.05),
                'td_error': td_error_scorer,
                'discounted_advantage': discounted_sum_of_advantage_scorer,
                'value_scale': average_value_estimation_scorer
            })
Example #5
File: fqe_atari.py  Project: wx-b/d3rlpy
from sklearn.model_selection import train_test_split
from d3rlpy.datasets import get_atari
from d3rlpy.algos import DiscreteCQL
from d3rlpy.ope import DiscreteFQE
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.metrics.scorer import initial_state_value_estimation_scorer
from d3rlpy.metrics.scorer import soft_opc_scorer

dataset, env = get_atari('breakout-expert-v0')

train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

# train algorithm
cql = DiscreteCQL(n_epochs=100,
                  scaler='pixel',
                  q_func_factory='qr',
                  n_frames=4,
                  use_gpu=True)
cql.fit(train_episodes,
        eval_episodes=test_episodes,
        scorers={
            'environment': evaluate_on_environment(env, epsilon=0.05),
            'init_value': initial_state_value_estimation_scorer,
            'soft_opc': soft_opc_scorer(70)
        })

# or load the trained model
# cql = DiscreteCQL.from_json('<path-to-json>/params.json')
# cql.load_model('<path-to-model>/model.pt')

# evaluate the trained policy
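# (not in the original excerpt) the file imports DiscreteFQE but the snippet is
# cut off before using it; a minimal sketch of the FQE evaluation step, with
# assumed hyperparameters
fqe = DiscreteFQE(algo=cql,
                  scaler='pixel',
                  n_frames=4,
                  use_gpu=True)

fqe.fit(test_episodes,
        eval_episodes=test_episodes,
        scorers={
            'init_value': initial_state_value_estimation_scorer,
            'soft_opc': soft_opc_scorer(70)
        })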
Example #6
import copy

from d3rlpy.algos import DoubleDQN
from d3rlpy.datasets import get_atari
from d3rlpy.online.buffers import ReplayBuffer
from d3rlpy.online.explorers import LinearDecayEpsilonGreedy

# get wrapped atari environment
_, env = get_atari('breakout-mixed-v0')
eval_env = copy.deepcopy(env)

# setup algorithm
dqn = DoubleDQN(batch_size=32,
                learning_rate=2.5e-4,
                target_update_interval=10000 // 4,
                q_func_type='qr',
                scaler='pixel',
                n_frames=4,
                use_gpu=True)

# replay buffer for experience replay
buffer = ReplayBuffer(maxlen=1000000, env=env)

# epsilon-greedy explorer
explorer = LinearDecayEpsilonGreedy(start_epsilon=1.0,
                                    end_epsilon=0.1,
                                    duration=1000000)

# start training
dqn.fit_online(env,
               buffer,
               explorer,
               eval_env=eval_env)  # remaining kwargs omitted in the original excerpt
Example #7
import numpy as np

from d3rlpy.datasets import get_atari
from d3rlpy.dataset import MDPDataset
from minerva.dataset import export_mdp_dataset_as_csv

# prepare MDPDataset
dataset, _ = get_atari('breakout-mixed-v0')

# take only the first 30 episodes to keep the dataset small
episodes = dataset.episodes[:30]

observations = []
actions = []
rewards = []
terminals = []

for episode in episodes:
    observations.append(episode.observations)
    actions.append(episode.actions.reshape(-1))
    rewards.append(episode.rewards.reshape(-1))
    flag = np.zeros(episode.observations.shape[0])
    flag[-1] = 1.0
    terminals.append(flag)

observations = np.vstack(observations)
actions = np.hstack(actions)
rewards = np.hstack(rewards)
terminals = np.hstack(terminals)

dataset = MDPDataset(observations=observations,
                     actions=actions,
                     rewards=rewards,
                     terminals=terminals)
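
# (not in the original excerpt) the imported export helper is presumably applied
# to the rebuilt dataset; the output path here is illustrative
export_mdp_dataset_as_csv(dataset, 'breakout_mixed.csv')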
Example #8
from d3rlpy.algos import DiscreteCQL
from d3rlpy.models.optimizers import AdamFactory
from d3rlpy.datasets import get_atari
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.metrics.scorer import average_value_estimation_scorer
from sklearn.model_selection import train_test_split

dataset, env = get_atari('breakout-medium-v0')

_, test_episodes = train_test_split(dataset, test_size=0.2)

cql = DiscreteCQL(optim_factory=AdamFactory(eps=1e-2 / 32),
                  scaler='pixel',
                  n_frames=4,
                  q_func_factory='qr',
                  use_gpu=True)

cql.fit(dataset.episodes,
        eval_episodes=test_episodes,
        n_epochs=2000,
        scorers={
            'environment': evaluate_on_environment(env, epsilon=0.001),
            'value_scale': average_value_estimation_scorer
        })