Exemplo n.º 1
0
import gym

from tf2rl.algos.vpg import VPG
from tf2rl.experiments.on_policy_trainer import OnPolicyTrainer
from tf2rl.envs.utils import is_discrete, get_act_dim


if __name__ == '__main__':
    parser = OnPolicyTrainer.get_argument()
    parser = VPG.get_argument(parser)
    parser.add_argument('--env-name', type=str,
                        default="Pendulum-v0")
    parser.set_defaults(test_interval=10240)
    parser.set_defaults(max_steps=int(1e7))
    parser.set_defaults(horizon=512)
    parser.set_defaults(batch_size=32)
    parser.set_defaults(gpu=-1)
    args = parser.parse_args()

    env = gym.make(args.env_name)
    test_env = gym.make(args.env_name)
    policy = VPG(
        state_shape=env.observation_space.shape,
        action_dim=get_act_dim(env.action_space),
        is_discrete=is_discrete(env.action_space),
        max_action=None if is_discrete(
            env.action_space) else env.action_space.high[0],
        batch_size=args.batch_size,
        actor_units=[32, 32],
        critic_units=[32, 32],
        discount=0.9,
Exemplo n.º 2
0
 def setUpClass(cls):
     """Create the continuous-action VPG agent shared by every test in the class.

     ``cls.continuous_env`` is read but not assigned here; presumably a base
     test class provides it (NOTE(review): confirm against the base class).
     unittest does not call a base class's ``setUpClass`` when a subclass
     overrides it, so the base hook is invoked explicitly first. If the
     base is plain ``unittest.TestCase``, that call is a no-op.
     """
     super().setUpClass()
     env = cls.continuous_env
     cls.agent = VPG(state_shape=env.observation_space.shape,
                     # presumably a Box action space; ``low.size`` counts its
                     # elements, i.e. the action dimensionality — confirm
                     action_dim=env.action_space.low.size,
                     is_discrete=False,
                     # NOTE(review): -1 presumably selects CPU execution; confirm in VPG
                     gpu=-1)
Exemplo n.º 3
0
 def setUpClass(cls):
     """Create the discrete-action VPG agent shared by every test in the class.

     ``cls.discrete_env`` is read but not assigned here; presumably a base
     test class provides it (NOTE(review): confirm against the base class).
     unittest does not call a base class's ``setUpClass`` when a subclass
     overrides it, so the base hook is invoked explicitly first. If the
     base is plain ``unittest.TestCase``, that call is a no-op.
     """
     super().setUpClass()
     env = cls.discrete_env
     cls.agent = VPG(state_shape=env.observation_space.shape,
                     # presumably a Discrete action space; ``n`` is its number
                     # of actions — confirm
                     action_dim=env.action_space.n,
                     is_discrete=True,
                     # NOTE(review): -1 presumably selects CPU execution; confirm in VPG
                     gpu=-1)
Exemplo n.º 4
0
from tf2rl.algos.vpg import VPG
from tf2rl.experiments.on_policy_trainer import OnPolicyTrainer
from tf2rl.envs.utils import is_discrete, get_act_dim, make

if __name__ == '__main__':
    parser = OnPolicyTrainer.get_argument()
    parser = VPG.get_argument(parser)
    parser.add_argument('--env-name', type=str, default="Pendulum-v0")
    parser.set_defaults(test_interval=10240)
    parser.set_defaults(max_steps=int(1e7))
    parser.set_defaults(horizon=512)
    parser.set_defaults(batch_size=32)
    parser.set_defaults(gpu=-1)
    args = parser.parse_args()

    env = make(args.env_name)
    test_env = make(args.env_name)
    policy = VPG(state_shape=env.observation_space.shape,
                 action_dim=get_act_dim(env.action_space),
                 is_discrete=is_discrete(env.action_space),
                 max_action=None if is_discrete(env.action_space) else
                 env.action_space.high[0],
                 batch_size=args.batch_size,
                 actor_units=(64, 64),
                 critic_units=(64, 64),
                 n_epoch=10,
                 lr_actor=3e-4,
                 lr_critic=3e-4,
                 hidden_activation_actor="tanh",
                 hidden_activation_critic="tanh",
                 discount=0.9,