import gym

from tf2rl.algos.vpg import VPG
from tf2rl.experiments.on_policy_trainer import OnPolicyTrainer
from tf2rl.envs.utils import is_discrete, get_act_dim


if __name__ == '__main__':
    parser = OnPolicyTrainer.get_argument()
    parser = VPG.get_argument(parser)
    parser.add_argument('--env-name', type=str, default="Pendulum-v0")
    parser.set_defaults(test_interval=10240)
    parser.set_defaults(max_steps=int(1e7))
    parser.set_defaults(horizon=512)
    parser.set_defaults(batch_size=32)
    parser.set_defaults(gpu=-1)
    args = parser.parse_args()

    env = gym.make(args.env_name)
    test_env = gym.make(args.env_name)
    policy = VPG(
        state_shape=env.observation_space.shape,
        action_dim=get_act_dim(env.action_space),
        is_discrete=is_discrete(env.action_space),
        max_action=None if is_discrete(
            env.action_space) else env.action_space.high[0],
        batch_size=args.batch_size,
        actor_units=[32, 32],
        critic_units=[32, 32],
        discount=0.9,
        # The source excerpt is truncated here; the closing argument and the
        # trainer setup below are an assumed completion following the usual
        # tf2rl example structure.
        gpu=args.gpu)
    trainer = OnPolicyTrainer(policy, env, args, test_env=test_env)
    trainer()
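# Usage note (an assumption, not shown in the source): saved as run_vpg.py, the
# example above would typically be launched as
#   python run_vpg.py --env-name Pendulum-v0
# with every other setting taken from the defaults registered via parser.set_defaults()
# and the flags added by OnPolicyTrainer.get_argument() / VPG.get_argument().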
@classmethod
def setUpClass(cls):
    # Test fixture: a VPG agent configured for a continuous action space.
    cls.agent = VPG(
        state_shape=cls.continuous_env.observation_space.shape,
        action_dim=cls.continuous_env.action_space.low.size,
        is_discrete=False,
        gpu=-1)
@classmethod
def setUpClass(cls):
    # Test fixture: a VPG agent configured for a discrete action space.
    cls.agent = VPG(
        state_shape=cls.discrete_env.observation_space.shape,
        action_dim=cls.discrete_env.action_space.n,
        is_discrete=True,
        gpu=-1)
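# A minimal, self-contained sketch of how fixtures like the two above can be assembled
# into a full unittest.TestCase. The class name TestVPGDiscrete, the CartPole-v0
# environment, and the get_action check are illustrative assumptions, not part of the
# tf2rl test suite excerpted here; the sketch also assumes VPG.get_action(obs) returns
# a single action index for a single observation.
import unittest

import gym

from tf2rl.algos.vpg import VPG


class TestVPGDiscrete(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Build the environment and agent once for all tests in the class.
        cls.discrete_env = gym.make("CartPole-v0")
        cls.agent = VPG(
            state_shape=cls.discrete_env.observation_space.shape,
            action_dim=cls.discrete_env.action_space.n,
            is_discrete=True,
            gpu=-1)

    def test_get_action(self):
        # A single observation should map to a valid discrete action index.
        obs = self.discrete_env.reset()
        action = self.agent.get_action(obs)
        self.assertIn(int(action), range(self.discrete_env.action_space.n))


if __name__ == '__main__':
    unittest.main()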
from tf2rl.algos.vpg import VPG
from tf2rl.experiments.on_policy_trainer import OnPolicyTrainer
from tf2rl.envs.utils import is_discrete, get_act_dim, make


if __name__ == '__main__':
    parser = OnPolicyTrainer.get_argument()
    parser = VPG.get_argument(parser)
    parser.add_argument('--env-name', type=str, default="Pendulum-v0")
    parser.set_defaults(test_interval=10240)
    parser.set_defaults(max_steps=int(1e7))
    parser.set_defaults(horizon=512)
    parser.set_defaults(batch_size=32)
    parser.set_defaults(gpu=-1)
    args = parser.parse_args()

    env = make(args.env_name)
    test_env = make(args.env_name)
    policy = VPG(
        state_shape=env.observation_space.shape,
        action_dim=get_act_dim(env.action_space),
        is_discrete=is_discrete(env.action_space),
        max_action=None if is_discrete(env.action_space) else env.action_space.high[0],
        batch_size=args.batch_size,
        actor_units=(64, 64),
        critic_units=(64, 64),
        n_epoch=10,
        lr_actor=3e-4,
        lr_critic=3e-4,
        hidden_activation_actor="tanh",
        hidden_activation_critic="tanh",
        discount=0.9,
        # The source excerpt is truncated here; the closing argument and the
        # trainer setup below are an assumed completion following the usual
        # tf2rl example structure.
        gpu=args.gpu)
    trainer = OnPolicyTrainer(policy, env, args, test_env=test_env)
    trainer()
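# A minimal evaluation sketch (an assumption, not part of the tf2rl example above):
# once OnPolicyTrainer has run, the same `policy` object can be rolled out greedily
# on `test_env`. The helper name evaluate_once is hypothetical; the sketch assumes
# VPG.get_action(obs, test=True) returns a single action for a single observation
# and uses the classic gym step API that the scripts above rely on.
def evaluate_once(policy, test_env, max_episode_steps=200):
    """Roll out one episode with the test-mode policy and return the episode return."""
    obs = test_env.reset()
    episode_return = 0.0
    for _ in range(max_episode_steps):
        action = policy.get_action(obs, test=True)
        obs, reward, done, _ = test_env.step(action)
        episode_return += float(reward)
        if done:
            break
    return episode_return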