示例#1
0
    def test_run_discrete(self):
        """Exercise ``run`` end-to-end with a DQN agent on CartPole-v0.

        The DQN argument parser is layered on top of ``self.parser`` with
        ``n_warmup`` defaulting to 1, and ``run`` receives four local
        callbacks:

        * ``env_fn`` -- creates a fresh ``CartPole-v0`` environment.
        * ``policy_fn`` -- builds a small DQN policy for a given env.
        * ``get_weights_fn`` -- returns ``[online, target]`` Q-network weights.
        * ``set_weights_fn`` -- writes such a pair back into a policy.
        """
        from tf2rl.algos.dqn import DQN

        dqn_parser = DQN.get_argument(self.parser)
        dqn_parser.set_defaults(n_warmup=1)
        # parse_known_args() returns (namespace, leftovers); only the
        # namespace is used, so unknown CLI flags are tolerated.
        cli_args = dqn_parser.parse_known_args()[0]

        def env_fn():
            return gym.make("CartPole-v0")

        def policy_fn(env, name, memory_capacity=int(1e6), gpu=-1,
                      *args, **kwargs):
            # NOTE(review): the ``gpu`` argument is ignored; the policy is
            # always built with gpu=-1. Presumably this keeps the test on the
            # CPU -- confirm before forwarding ``gpu``. Extra positional and
            # keyword arguments are accepted and discarded.
            dqn_kwargs = dict(
                name=name,
                state_shape=env.observation_space.shape,
                action_dim=env.action_space.n,
                n_warmup=500,
                target_replace_interval=300,
                batch_size=32,
                memory_capacity=memory_capacity,
                discount=0.99,
                gpu=-1,
            )
            return DQN(**dqn_kwargs)

        def get_weights_fn(policy):
            # Order matters: set_weights_fn unpacks (online, target).
            online = policy.q_func.weights
            target = policy.q_func_target.weights
            return [online, target]

        def set_weights_fn(policy, weights):
            online_weights, target_weights = weights
            # NOTE(review): tau=1. presumably performs a full (hard) copy
            # rather than a soft update -- confirm against
            # update_target_variables.
            destinations = (
                (policy.q_func, online_weights),
                (policy.q_func_target, target_weights),
            )
            for network, source in destinations:
                update_target_variables(network.weights, source, tau=1.)

        run(cli_args, env_fn, policy_fn, get_weights_fn, set_weights_fn)
示例#2
0
    def test_run_discrete(self):
        """Exercise ``run`` with the module-level discrete-action callbacks.

        The DQN argument parser is layered on top of ``self.parser`` with
        ``n_warmup`` defaulting to 1. The environment, policy and weight
        get/set factories are the ``*_discrete`` functions resolved from
        module scope.
        """
        from tf2rl.algos.dqn import DQN

        dqn_parser = DQN.get_argument(self.parser)
        dqn_parser.set_defaults(n_warmup=1)
        # Keep only the parsed namespace; unrecognised CLI flags are ignored.
        cli_args = dqn_parser.parse_known_args()[0]

        callbacks = (
            env_fn_discrete,
            policy_fn_discrete,
            get_weights_fn_discrete,
            set_weights_fn_discrete,
        )
        run(cli_args, *callbacks)
示例#3
0
文件: run_dqn.py 项目: ymd-h/tf2rl
from tf2rl.algos.dqn import DQN
from tf2rl.experiments.trainer import Trainer
from tf2rl.envs.utils import make

if __name__ == '__main__':
    # Command line: Trainer options extended with DQN-specific options.
    parser = DQN.get_argument(Trainer.get_argument())
    # Script-level defaults; each can still be overridden on the command line.
    parser.set_defaults(
        test_interval=2000,
        max_steps=100000,
        gpu=-1,
        n_warmup=500,
        batch_size=32,
        memory_capacity=int(1e4),
    )
    parser.add_argument('--env-name', type=str, default="CartPole-v0")
    args = parser.parse_args()

    # Two independent instances of the same environment: one handed to the
    # trainer as ``env``, the other as ``test_env``.
    env, test_env = make(args.env_name), make(args.env_name)
    policy = DQN(
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.n,
        enable_double_dqn=args.enable_double_dqn,
        enable_dueling_dqn=args.enable_dueling_dqn,
        enable_noisy_dqn=args.enable_noisy_dqn,
        gpu=args.gpu,
        batch_size=args.batch_size,
        memory_capacity=args.memory_capacity,
        n_warmup=args.n_warmup,
        # Fixed hyper-parameters that are not exposed on the command line.
        target_replace_interval=300,
        discount=0.99,
    )
    trainer = Trainer(policy, env, args, test_env=test_env)
    if args.evaluate: