Example #1
import os

import gym

from rl.agents.trpo import TRPO
from rl.loops import EpisodeTrainLoop
from rl.metrics import AverageReturn, AverageEpisodeLength
from zoo.cartpole.core import PolicyNetwork, ValueFunctionNetwork
from zoo.utils import parse_args, get_output_dirs, evaluate_policy

if __name__ == '__main__':
    args = parse_args()
    ckpt_dir, log_dir = get_output_dirs(os.path.dirname(__file__), 'trpo',
                                        args)

    env = gym.make('CartPole-v0')
    policy_fn = lambda: PolicyNetwork(env.observation_space.shape,
                                      env.action_space.n)
    vf_fn = lambda: ValueFunctionNetwork(env.observation_space.shape)
    agent = TRPO(
        env=env,
        policy_fn=policy_fn,
        vf_fn=vf_fn,
        lr_vf=1e-3,
        gamma=0.98,
        lambda_=0.96,
        delta=0.001,
        replay_buffer_size=250 * 8,
        policy_update_batch_size=512,
        vf_update_batch_size=512,
        vf_update_iterations=20,
    )
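The example stops before the training loop, so the imported EpisodeTrainLoop, metric, and evaluate_policy helpers never appear above and their signatures are unknown here. As a rough substitute, the following is a minimal evaluation-rollout sketch that uses only the classic Gym API (reset/step as in gym < 0.26, which CartPole-v0 implies); rollout and select_action are illustrative names, not part of the repo.

import gym

def rollout(env, select_action, n_episodes=10):
    # Average undiscounted return over a few episodes.
    returns = []
    for _ in range(n_episodes):
        obs, done, ep_return = env.reset(), False, 0.0
        while not done:
            obs, reward, done, _ = env.step(select_action(obs))
            ep_return += reward
        returns.append(ep_return)
    return sum(returns) / len(returns)

# Usage with a random baseline policy:
env = gym.make('CartPole-v0')
print(rollout(env, lambda obs: env.action_space.sample()))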
Example #2
import os

import gym

from rl.agents.ppo_penalty import PPOPenalty
from rl.loops import EpisodeTrainLoop
from rl.metrics import AverageEpisodeLength, AverageReturn
from zoo.cartpole.core import PolicyNetwork, ValueFunctionNetwork
from zoo.utils import parse_args, get_output_dirs, evaluate_policy

if __name__ == '__main__':
    args = parse_args()
    ckpt_dir, log_dir = get_output_dirs(os.path.dirname(__file__),
                                        'ppo_penalty', args)

    env = gym.make('CartPole-v0')
    policy_fn = lambda: PolicyNetwork(env.observation_space.shape,
                                      env.action_space.n)
    vf_fn = lambda: ValueFunctionNetwork(env.observation_space.shape)
    agent = PPOPenalty(
        env=env,
        policy_fn=policy_fn,
        vf_fn=vf_fn,
        lr_vf=1e-3,
        lr_policy=1e-3,
        gamma=0.98,
        lambda_=0.96,
        beta=1.0,
        kl_target=0.001,
        kl_tolerance=1.5,
        beta_update_factor=2,
    )
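The beta, kl_target, kl_tolerance, and beta_update_factor arguments match the adaptive KL-penalty variant of PPO (Schulman et al., 2017). Assuming PPOPenalty follows that rule, a minimal sketch of the penalty-coefficient update looks like this (update_beta is an illustrative helper, not part of the repo):

def update_beta(beta, kl, kl_target=0.001, kl_tolerance=1.5, factor=2):
    # Grow the KL penalty when the new policy drifted too far from the old
    # one, shrink it when the update was overly timid, otherwise keep it.
    if kl > kl_tolerance * kl_target:
        return beta * factor
    if kl < kl_target / kl_tolerance:
        return beta / factor
    return beta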
Example #3
import os

from rl.agents.alpha_zero import AlphaZero
from zoo.connect_4.core import Connect4, PolicyAndValueFunctionNetwork
from zoo.utils import parse_args, get_output_dirs

if __name__ == '__main__':
    args = parse_args()
    ckpt_dir, log_dir, replay_dir = get_output_dirs(os.path.dirname(__file__),
                                                    'alpha_zero', args,
                                                    ['ckpt', 'log', 'replay'])

    game_fn = lambda: Connect4()
    policy_and_vf_fn = lambda: PolicyAndValueFunctionNetwork(
        observation_shape=Connect4.observation_space.shape,
        n_actions=Connect4.action_space.n,
        l2=3e-4,
    )
    agent = AlphaZero(
        game_fn=game_fn,
        policy_and_vf_fn=policy_and_vf_fn,
        lr=1e-3,
        mcts_n_steps=500,
        mcts_tau=6 * 7,
        mcts_eta=0.03,
        mcts_epsilon=0.25,
        mcts_c_puct=1,
        n_self_play_workers=8,
        update_iterations=10_000_000,
        update_batch_size=64,
        replay_buffer_size=50_000,
        ckpt_dir=ckpt_dir,  # assumed continuation, mirroring Example #5
        log_dir=log_dir,
    )
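mcts_eta=0.03 and mcts_epsilon=0.25 look like the Dirichlet-noise parameters from the AlphaZero paper. Assuming that is what they control here, this is how such noise is typically mixed into the prior probabilities at the search root (add_root_noise is an illustrative helper, not part of the repo):

import numpy as np

def add_root_noise(priors, eta=0.03, epsilon=0.25):
    # Encourage exploration by blending Dirichlet(eta) noise into the
    # network's prior over actions, at the root node only.
    noise = np.random.dirichlet([eta] * len(priors))
    return (1 - epsilon) * np.asarray(priors) + epsilon * noise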
Example #4
import os

import gym

from rl.agents.vpg_gae import VPGGAE
from rl.loops import EpisodeTrainLoop
from rl.metrics import AverageReturn, AverageEpisodeLength
from zoo.cartpole.core import PolicyNetwork, ValueFunctionNetwork
from zoo.utils import parse_args, get_output_dirs, evaluate_policy

if __name__ == '__main__':
    args = parse_args()
    ckpt_dir, log_dir = get_output_dirs(os.path.dirname(__file__), 'vpg_gae',
                                        args)

    env = gym.make('CartPole-v0')
    policy_fn = lambda: PolicyNetwork(env.observation_space.shape,
                                      env.action_space.n)
    vf_fn = lambda: ValueFunctionNetwork(env.observation_space.shape)
    agent = VPGGAE(
        env=env,
        policy_fn=policy_fn,
        vf_fn=vf_fn,
        lr_policy=1e-3,
        lr_vf=1e-3,
        gamma=0.98,
        lambda_=0.96,
        vf_update_iterations=20,
        replay_buffer_size=250 * 2,
        policy_update_batch_size=256,
        vf_update_batch_size=256,
    )
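gamma and lambda_ are the discount factor and the GAE coefficient. A minimal sketch of generalized advantage estimation (Schulman et al., 2016) for a single finished trajectory, assuming this is what VPGGAE computes internally (gae_advantages is an illustrative helper, not part of the repo):

import numpy as np

def gae_advantages(rewards, values, last_value, gamma=0.98, lambda_=0.96):
    # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    # A_t = sum_l (gamma * lambda_)^l * delta_{t+l}, accumulated backwards.
    values = np.append(values, last_value)
    advantages = np.zeros(len(rewards))
    gae = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae = delta + gamma * lambda_ * gae
        advantages[t] = gae
    return advantages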
Example #5
import os

from rl.agents.alpha_zero import AlphaZero
from zoo.tic_tac_toe.core import TicTacToe, PolicyAndValueFunctionNetwork
from zoo.utils import parse_args, get_output_dirs

if __name__ == '__main__':
    args = parse_args()
    ckpt_dir, log_dir = get_output_dirs(os.path.dirname(__file__), 'alpha_zero', args)

    game_fn = lambda: TicTacToe()
    policy_and_vf_fn = lambda: PolicyAndValueFunctionNetwork(
        observation_shape=TicTacToe.observation_space.shape,
        n_actions=TicTacToe.action_space.n,
        l2=1e-3,
    )
    agent = AlphaZero(
        game_fn=game_fn,
        policy_and_vf_fn=policy_and_vf_fn,
        lr=1e-3,
        mcts_n_steps=100,
        mcts_tau=9,
        mcts_eta=0.03,
        mcts_epsilon=0.25,
        mcts_c_puct=1,
        n_self_play_workers=8,
        update_iterations=10_000,
        update_batch_size=64,
        replay_buffer_size=10_000,
        ckpt_dir=ckpt_dir,
        log_dir=log_dir,
    )
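mcts_tau is 9 here and 6 * 7 in the Connect 4 example, i.e. the number of board cells, which suggests it is the number of opening moves sampled with temperature 1 before switching to greedy play, as in AlphaGo Zero. Under that assumption, a sketch of move selection from root visit counts (select_move is an illustrative helper, not part of the repo):

import numpy as np

def select_move(visit_counts, move_number, tau_moves=9):
    # Sample proportionally to visit counts during the first tau_moves plies,
    # then play the most-visited move deterministically.
    counts = np.asarray(visit_counts, dtype=np.float64)
    if move_number < tau_moves:
        return int(np.random.choice(len(counts), p=counts / counts.sum()))
    return int(counts.argmax())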