Code example #1
    def test_train(self):
        agent = DQN(
            state_shape=self.env.observation_space.shape,
            action_dim=self.env.action_space.n,
            memory_capacity=100,
            gpu=-1)
        import numpy as np
        from cpprb import ReplayBuffer
        replay_buffer = ReplayBuffer(
            obs_dim=self.env.observation_space.shape,
            act_dim=1,
            size=agent.memory_capacity)

        obs = self.env.reset()
        # Collect 100 transitions with the agent's exploration policy and store them in the buffer
        for _ in range(100):
            action = agent.get_action(obs)
            next_obs, reward, done, _ = self.env.step(action)
            replay_buffer.add(obs=obs, act=action, next_obs=next_obs, rew=reward, done=done)
            if done:
                next_obs = self.env.reset()
            obs = next_obs

        # Sample mini-batches from the buffer and run 100 training updates
        for _ in range(100):
            samples = replay_buffer.sample(agent.batch_size)
            agent.train(samples["obs"], samples["act"], samples["next_obs"],
                        samples["rew"], np.array(samples["done"], dtype=np.float64))
Code example #2
 def __call__(self,
              env,
              name,
              memory_capacity=int(1e6),
              gpu=-1,
              noise_level=0.3):
     return DQN(name=name,
                enable_double_dqn=self.args.enable_double_dqn,
                enable_dueling_dqn=self.args.enable_dueling_dqn,
                enable_noisy_dqn=self.args.enable_noisy_dqn,
                enable_categorical_dqn=self.args.enable_categorical_dqn,
                state_shape=env.observation_space.shape,
                action_dim=env.action_space.n,
                n_warmup=self.n_warmup,
                target_replace_interval=self.target_replace_interval,
                batch_size=self.batch_size,
                memory_capacity=memory_capacity,
                discount=0.99,
                epsilon=1.,
                epsilon_min=0.1,
                epsilon_decay_step=self.epsilon_decay_rate,
                optimizer=self.optimizer,
                update_interval=4,
                q_func=self.QFunc,
                gpu=gpu)
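This callable has the same (env, name, memory_capacity, gpu, ...) signature as the policy_fn callbacks in the Ape-X examples below, so a single factory object can build a fully configured DQN for every worker process.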
Code example #3
File: test_apex.py, Project: chentianba/coflowgym
    def test_run_discrete(self):
        # gym, run and update_target_variables are imported at module level in the original test file
        from tf2rl.algos.dqn import DQN
        parser = DQN.get_argument(self.parser)
        parser.set_defaults(n_warmup=1)
        args, _ = parser.parse_known_args()

        def env_fn():
            return gym.make("CartPole-v0")

        def policy_fn(env, name, memory_capacity=int(1e6), gpu=-1, *args, **kwargs):
            return DQN(
                name=name,
                state_shape=env.observation_space.shape,
                action_dim=env.action_space.n,
                n_warmup=500,
                target_replace_interval=300,
                batch_size=32,
                memory_capacity=memory_capacity,
                discount=0.99,
                gpu=-1)

        def get_weights_fn(policy):
            return [policy.q_func.weights,
                    policy.q_func_target.weights]

        def set_weights_fn(policy, weights):
            q_func_weights, qfunc_target_weights = weights
            update_target_variables(
                policy.q_func.weights, q_func_weights, tau=1.)
            update_target_variables(
                policy.q_func_target.weights, qfunc_target_weights, tau=1.)

        run(args, env_fn, policy_fn, get_weights_fn, set_weights_fn)
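Here get_weights_fn and set_weights_fn exchange the weights of the online and target Q-networks between the Ape-X processes; calling update_target_variables with tau=1. amounts to a hard copy of the received weights.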
Code example #4
File: test_dqn.py, Project: chentianba/coflowgym
 def setUpClass(cls):
     super().setUpClass()
     cls.agent = DQN(state_shape=cls.discrete_env.observation_space.shape,
                     action_dim=cls.discrete_env.action_space.n,
                     batch_size=cls.batch_size,
                     epsilon=1.,
                     gpu=-1)
Code example #5
    def test_run_discrete(self):
        from tf2rl.algos.dqn import DQN
        parser = DQN.get_argument(self.parser)
        parser.set_defaults(n_warmup=1)
        args, _ = parser.parse_known_args()

        run(args, env_fn_discrete, policy_fn_discrete, get_weights_fn_discrete,
            set_weights_fn_discrete)
Code example #6
File: test_dqn.py, Project: chentianba/coflowgym
 def setUpClass(cls):
     super().setUpClass()
     cls.agent = DQN(state_shape=cls.discrete_env.observation_space.shape,
                     action_dim=cls.discrete_env.action_space.n,
                     batch_size=cls.batch_size,
                     enable_categorical_dqn=True,
                     enable_dueling_dqn=True,
                     epsilon=1.,
                     gpu=-1)
Code example #7
 def setUpClass(cls):
     cls.env = gym.make("CartPole-v0")
     policy = DQN(state_shape=cls.env.observation_space.shape,
                  action_dim=cls.env.action_space.n,
                  memory_capacity=2**4)
     cls.replay_buffer = get_replay_buffer(policy, cls.env)
     cls.output_dir = os.path.join(os.path.dirname(__file__), "tests")
     if not os.path.isdir(cls.output_dir):
         os.makedirs(cls.output_dir)
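This fixture builds a small CartPole DQN (memory_capacity=2**4), obtains a matching replay buffer through tf2rl's get_replay_buffer helper, and creates a tests/ output directory next to the test file.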
Code example #8
File: test_apex.py, Project: zhb0318/tf2rl
 def policy_fn(env, name, memory_capacity=int(1e6), gpu=-1):
     return DQN(name=name,
                state_shape=env.observation_space.shape,
                action_dim=env.action_space.n,
                n_warmup=500,
                target_replace_interval=300,
                batch_size=32,
                memory_capacity=memory_capacity,
                discount=0.99,
                gpu=-1)
Code example #9
def policy_fn_discrete(env, name, memory_capacity=int(1e6), gpu=-1, *args, **kwargs):
    from tf2rl.algos.dqn import DQN
    return DQN(
        name=name,
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.n,
        n_warmup=500,
        target_replace_interval=300,
        batch_size=32,
        memory_capacity=memory_capacity,
        discount=0.99,
        gpu=-1)
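This is functionally the same factory as code example #8; the local import and the extra *args/**kwargs only make the function more tolerant of how the Ape-X runner calls it, while the DQN configuration stays identical.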
Code example #10
 def test_get_action(self):
     agent = DQN(
         state_shape=self.env.observation_space.shape,
         action_dim=self.env.action_space.n,
         gpu=-1)
     state = self.env.reset()
     agent.get_action(state, test=False)
     agent.get_action(state, test=True)
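get_action is exercised in both modes here: with test=False the agent may take exploratory (epsilon-greedy) actions, while test=True is intended for deterministic evaluation.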
Code example #11
File: run_dqn.py, Project: ymd-h/tf2rl
from tf2rl.algos.dqn import DQN
from tf2rl.experiments.trainer import Trainer
from tf2rl.envs.utils import make

if __name__ == '__main__':
    parser = Trainer.get_argument()
    parser = DQN.get_argument(parser)
    parser.set_defaults(test_interval=2000)
    parser.set_defaults(max_steps=100000)
    parser.set_defaults(gpu=-1)
    parser.set_defaults(n_warmup=500)
    parser.set_defaults(batch_size=32)
    parser.set_defaults(memory_capacity=int(1e4))
    parser.add_argument('--env-name', type=str, default="CartPole-v0")
    args = parser.parse_args()

    env = make(args.env_name)
    test_env = make(args.env_name)
    policy = DQN(enable_double_dqn=args.enable_double_dqn,
                 enable_dueling_dqn=args.enable_dueling_dqn,
                 enable_noisy_dqn=args.enable_noisy_dqn,
                 state_shape=env.observation_space.shape,
                 action_dim=env.action_space.n,
                 target_replace_interval=300,
                 discount=0.99,
                 gpu=args.gpu,
                 memory_capacity=args.memory_capacity,
                 batch_size=args.batch_size,
                 n_warmup=args.n_warmup)
    trainer = Trainer(policy, env, args, test_env=test_env)
    if args.evaluate:
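The excerpt stops inside this final branch; in tf2rl's example scripts the evaluate branch typically calls trainer.evaluate_policy_continuously() and the other branch invokes trainer() to start training, but that tail is inferred rather than shown above.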
Code example #12
File: run_dqn_atari.py, Project: zw199502/tf2rl
import gym

from tf2rl.algos.dqn import DQN
from tf2rl.envs.atari_wrapper import wrap_dqn
from tf2rl.experiments.trainer import Trainer
from tf2rl.networks.atari_model import AtariQFunc as QFunc

if __name__ == '__main__':
    parser = Trainer.get_argument()
    parser = DQN.get_argument(parser)
    parser.add_argument('--env-name',
                        type=str,
                        default="SpaceInvadersNoFrameskip-v4")
    parser.set_defaults(episode_max_steps=108000)
    parser.set_defaults(test_interval=10000)
    parser.set_defaults(max_steps=int(1e9))
    parser.set_defaults(save_model_interval=500000)
    parser.set_defaults(gpu=0)
    parser.set_defaults(show_test_images=True)
    parser.set_defaults(memory_capacity=int(1e6))
    args = parser.parse_args()

    env = wrap_dqn(gym.make(args.env_name))
    test_env = wrap_dqn(gym.make(args.env_name), reward_clipping=False)
    # The following parameters are equivalent to those used in the DeepMind DQN paper:
    # https://www.nature.com/articles/nature14236
    policy = DQN(
        enable_double_dqn=args.enable_double_dqn,
        enable_dueling_dqn=args.enable_dueling_dqn,
        enable_noisy_dqn=args.enable_noisy_dqn,
        state_shape=env.observation_space.shape,
Code example #13
# Excerpt from the middle of an Ape-X launcher script; tf, DQN, AtariQFunc,
# apex_argument and update_target_variables are imported earlier in the original file.
def get_weights_fn(policy):
    return [policy.q_func.weights, policy.q_func_target.weights]


def set_weights_fn(policy, weights):
    q_func_weights, qfunc_target_weights = weights
    update_target_variables(policy.q_func.weights, q_func_weights, tau=1.)
    update_target_variables(policy.q_func_target.weights,
                            qfunc_target_weights,
                            tau=1.)


if __name__ == '__main__':
    parser = apex_argument()
    parser = DQN.get_argument(parser)
    parser.add_argument('--atari', action='store_true')
    parser.add_argument('--env-name',
                        type=str,
                        default="SpaceInvadersNoFrameskip-v4")
    args = parser.parse_args()

    if args.atari:
        env_name = args.env_name
        n_warmup = 50000
        target_replace_interval = 10000
        batch_size = 32
        optimizer = tf.keras.optimizers.Adam(learning_rate=0.0000625,
                                             epsilon=1.5e-4)
        epsilon_decay_rate = int(1e6)
        QFunc = AtariQFunc
Code example #14
import gym

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Flatten

from tf2rl.algos.dqn import DQN
from tf2rl.networks.noisy_dense import NoisyDense
from tf2rl.envs.atari_wrapper import wrap_dqn
from tf2rl.experiments.trainer import Trainer
from tf2rl.networks.dqn_model import AtariQFunc as QFunc

if __name__ == '__main__':
    parser = Trainer.get_argument()
    parser = DQN.get_argument(parser)
    parser.add_argument("--replay-buffer-size", type=int, default=int(1e6))
    parser.add_argument('--env-name',
                        type=str,
                        default="SpaceInvadersNoFrameskip-v4")
    parser.set_defaults(episode_max_steps=108000)
    parser.set_defaults(test_interval=10000)
    parser.set_defaults(max_steps=int(1e9))
    parser.set_defaults(save_model_interval=500000)
    parser.set_defaults(gpu=0)
    parser.set_defaults(show_test_images=True)
    args = parser.parse_args()

    env = wrap_dqn(gym.make(args.env_name))
    test_env = wrap_dqn(gym.make(args.env_name), reward_clipping=False)
    # The following parameters are equivalent to those used in the DeepMind DQN paper:
    # https://www.nature.com/articles/nature14236
Code example #15
            self(inputs=tf.constant(
                np.zeros(shape=input_shape, dtype=np.float64)))

    def call(self, inputs):
        features = self.conv1(inputs)
        features = self.conv2(features)
        features = self.conv3(features)
        features = self.flat(features)
        features = self.fc1(features)
        features = self.out(features)
        return features
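call() chains three convolution layers, a flatten and two dense layers, the usual Atari DQN Q-network layout; the self(...) call above with a zero-filled tensor simply performs a dummy forward pass so that the network's weights are created at construction time.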


if __name__ == '__main__':
    parser = Trainer.get_argument()
    parser = DQN.get_argument(parser)
    parser.add_argument("--replay-buffer-size", type=int, default=int(1e6))
    parser.set_defaults()
    parser.set_defaults(test_interval=10000)
    parser.set_defaults(gpu=-1)
    args = parser.parse_args()

    env = wrap_deepmind(gym.make('SpaceInvaders-v0'),
                        frame_stack=True,
                        scale=True)
    test_env = wrap_deepmind(gym.make('SpaceInvaders-v0'),
                             frame_stack=True,
                             scale=True)
    # The following parameters are equivalent to those used in the DeepMind DQN paper:
    # https://www.nature.com/articles/nature14236
    policy = DQN(enable_double_dqn=args.enable_double_dqn,
Code example #16
 def test__init__(self):
     DQN(state_shape=self.env.observation_space.shape,
         action_dim=self.env.action_space.n,
         gpu=-1)
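The examples above all come from tf2rl tests and launcher scripts. As a closing reference, here is a minimal end-to-end sketch assembled from those patterns; it assumes tf2rl and gym are installed, and the hyperparameter values and final trainer() call follow the usage implied by code example #11 rather than any single file shown here.
import gym

from tf2rl.algos.dqn import DQN
from tf2rl.experiments.trainer import Trainer

if __name__ == '__main__':
    # Build an argument namespace with the library defaults for Trainer and DQN.
    parser = Trainer.get_argument()
    parser = DQN.get_argument(parser)
    args = parser.parse_args()

    env = gym.make("CartPole-v0")
    test_env = gym.make("CartPole-v0")
    policy = DQN(state_shape=env.observation_space.shape,
                 action_dim=env.action_space.n,
                 n_warmup=500,
                 batch_size=32,
                 memory_capacity=int(1e4),
                 gpu=-1)
    trainer = Trainer(policy, env, args, test_env=test_env)
    trainer()  # Trainer instances are callable; this starts the training loop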