def test_run_discrete(self):
    # `gym`, `update_target_variables`, and the apex `run` helper are
    # imported at the top of the test module.
    from tf2rl.algos.dqn import DQN

    parser = DQN.get_argument(self.parser)
    parser.set_defaults(n_warmup=1)
    args, _ = parser.parse_known_args()

    def env_fn():
        return gym.make("CartPole-v0")

    def policy_fn(env, name, memory_capacity=int(1e6),
                  gpu=-1, *args, **kwargs):
        return DQN(
            name=name,
            state_shape=env.observation_space.shape,
            action_dim=env.action_space.n,
            n_warmup=500,  # policy-level warmup, separate from the parser default above
            target_replace_interval=300,
            batch_size=32,
            memory_capacity=memory_capacity,
            discount=0.99,
            gpu=gpu)  # forward the argument instead of hard-coding -1

    def get_weights_fn(policy):
        # Ship both the online and the target Q-network weights to workers.
        return [policy.q_func.weights,
                policy.q_func_target.weights]

    def set_weights_fn(policy, weights):
        # Hard-copy (tau=1.) the received weights into the local networks.
        q_func_weights, q_func_target_weights = weights
        update_target_variables(
            policy.q_func.weights, q_func_weights, tau=1.)
        update_target_variables(
            policy.q_func_target.weights, q_func_target_weights, tau=1.)

    run(args, env_fn, policy_fn, get_weights_fn, set_weights_fn)
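# The condensed test below calls module-level `*_discrete` helpers that are
# not shown in this excerpt. A minimal sketch of them, reconstructed from the
# inline closures above (the bodies are an assumption, not verbatim source):

import gym

from tf2rl.misc.target_update_ops import update_target_variables


def env_fn_discrete():
    return gym.make("CartPole-v0")


def policy_fn_discrete(env, name, memory_capacity=int(1e6),
                       gpu=-1, *args, **kwargs):
    from tf2rl.algos.dqn import DQN
    return DQN(
        name=name,
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.n,
        n_warmup=500,
        target_replace_interval=300,
        batch_size=32,
        memory_capacity=memory_capacity,
        discount=0.99,
        gpu=gpu)


def get_weights_fn_discrete(policy):
    return [policy.q_func.weights,
            policy.q_func_target.weights]


def set_weights_fn_discrete(policy, weights):
    q_func_weights, q_func_target_weights = weights
    update_target_variables(
        policy.q_func.weights, q_func_weights, tau=1.)
    update_target_variables(
        policy.q_func_target.weights, q_func_target_weights, tau=1.)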
def test_run_discrete(self):
    # Refactored variant of the test above: the inline closures are replaced
    # by the module-level *_discrete helpers.
    from tf2rl.algos.dqn import DQN

    parser = DQN.get_argument(self.parser)
    parser.set_defaults(n_warmup=1)
    args, _ = parser.parse_known_args()

    run(args, env_fn_discrete, policy_fn_discrete,
        get_weights_fn_discrete, set_weights_fn_discrete)
from tf2rl.algos.dqn import DQN
from tf2rl.experiments.trainer import Trainer
from tf2rl.envs.utils import make


if __name__ == '__main__':
    parser = Trainer.get_argument()
    parser = DQN.get_argument(parser)
    parser.set_defaults(test_interval=2000)
    parser.set_defaults(max_steps=100000)
    parser.set_defaults(gpu=-1)
    parser.set_defaults(n_warmup=500)
    parser.set_defaults(batch_size=32)
    parser.set_defaults(memory_capacity=int(1e4))
    parser.add_argument('--env-name', type=str, default="CartPole-v0")
    args = parser.parse_args()

    env = make(args.env_name)
    test_env = make(args.env_name)

    policy = DQN(
        enable_double_dqn=args.enable_double_dqn,
        enable_dueling_dqn=args.enable_dueling_dqn,
        enable_noisy_dqn=args.enable_noisy_dqn,
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.n,
        target_replace_interval=300,
        discount=0.99,
        gpu=args.gpu,
        memory_capacity=args.memory_capacity,
        batch_size=args.batch_size,
        n_warmup=args.n_warmup)
    trainer = Trainer(policy, env, args, test_env=test_env)
    # The snippet was truncated here; the completion below assumes tf2rl's
    # Trainer API, where __call__ runs training and
    # evaluate_policy_continuously runs evaluation-only mode.
    if args.evaluate:
        trainer.evaluate_policy_continuously()
    else:
        trainer()
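# Example invocations of this script. The flag names assume the arguments
# registered by Trainer.get_argument() and DQN.get_argument() (plus --env-name
# added above); this is a sketch and flag names may differ across tf2rl
# versions:
#
#   python run_dqn.py --env-name CartPole-v0 --gpu -1
#   python run_dqn.py --env-name CartPole-v0 --enable-double-dqn --enable-dueling-dqn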