from tf2rl.algos.sac import SAC
from tf2rl.algos.td3 import TD3
from tf2rl.algos.ppo import PPO
from tf2rl.experiments.trainer import Trainer
from tf2rl.experiments.on_policy_trainer import OnPolicyTrainer


def main():
    # Collect CLI arguments from both trainers and from the three candidate algorithms.
    parser = Trainer.get_argument()
    parser = OnPolicyTrainer.get_argument(parser)
    parser = SAC.get_argument(parser)
    parser = PPO.get_argument(parser)
    parser = TD3.get_argument(parser)
    # Flags selecting which algorithm (or debug mode) to run, and which environment.
    parser.add_argument('--SAC', action='store_true')
    parser.add_argument('--PPO', action='store_true')
    parser.add_argument('--TD3', action='store_true')
    parser.add_argument('--DEBUG', action='store_true')
    parser.add_argument('--env', type=int, default=0)
    parser.set_defaults(batch_size=32)  # 100
    parser.set_defaults(n_warmup=10000)  # 10000
    parser.set_defaults(max_steps=2e6)
    parser.set_defaults(gpu=0)
    parser.set_defaults(test_interval=200 * 100)
    parser.set_defaults(test_episodes=3)
    args = parser.parse_args()
    print(vars(args))
    run(parser)
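# The original snippet stops at run(parser); presumably the script closes with the
# usual entry-point guard (not shown in the snippet), e.g.:
if __name__ == '__main__':
    main()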
from tf2rl.algos.ppo import PPO
from tf2rl.experiments.on_policy_trainer import OnPolicyTrainer
from tf2rl.envs.utils import is_discrete, get_act_dim
from arm_env import ArmEnvironment
import rospy
import numpy as np

# Experiment flags: observation normalisation, fixed vs. varying goal,
# and whether the environment steps at a slower rate.
normalise_obs = False
static_goal = True
testing = False
slow_step = True

rospy.init_node("RL_agent")

parser = OnPolicyTrainer.get_argument()
parser = PPO.get_argument(parser)

slow_suffix = "_slow" if slow_step else ""
# parser.set_defaults(max_steps=5)

# Pick model/log directories that encode the flag combination.
if static_goal and normalise_obs:
    parser.set_defaults(model_dir='model_PPO_static_normed' + slow_suffix)
    parser.set_defaults(logdir='results/PPO_static_normed' + slow_suffix)
elif static_goal:
    parser.set_defaults(model_dir='model_PPO_static' + slow_suffix)
    parser.set_defaults(logdir='results/PPO_static' + slow_suffix)
elif normalise_obs:
    parser.set_defaults(model_dir='model_PPO_normed' + slow_suffix)
    parser.set_defaults(logdir='results/PPO_normed' + slow_suffix)
else:
    parser.set_defaults(model_dir='model_PPO' + slow_suffix)
    parser.set_defaults(logdir='results/PPO' + slow_suffix)
parser.set_defaults(normalise_obs=normalise_obs)
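# The snippet above is truncated; the remainder presumably follows the standard
# tf2rl on-policy pattern (mirroring the PPO branch of run() later in this
# collection). A minimal sketch, assuming ArmEnvironment() needs no constructor
# arguments and reusing a single environment instance for evaluation -- the real
# setup likely differs:
args = parser.parse_args()
env = ArmEnvironment()
policy = PPO(
    state_shape=env.observation_space.shape,
    action_dim=get_act_dim(env.action_space),
    is_discrete=is_discrete(env.action_space),
    max_action=None if is_discrete(env.action_space) else env.action_space.high[0],
    batch_size=args.batch_size,
    horizon=args.horizon,
    normalize_adv=args.normalize_adv,
    enable_gae=args.enable_gae,
    gpu=args.gpu)
trainer = OnPolicyTrainer(policy, env, args, test_env=env)
trainer()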
import gym

from tf2rl.algos.ppo import PPO
from tf2rl.envs.atari_wrapper import wrap_dqn
from tf2rl.experiments.on_policy_trainer import OnPolicyTrainer
from tf2rl.networks.atari_model import AtariCategoricalActorCritic

if __name__ == '__main__':
    parser = OnPolicyTrainer.get_argument()
    parser = PPO.get_argument(parser)
    parser.add_argument('--env-name', type=str, default="PongNoFrameskip-v4")
    parser.set_defaults(episode_max_steps=108000)
    parser.set_defaults(horizon=1024)
    parser.set_defaults(test_interval=204800)
    parser.set_defaults(max_steps=int(1e9))
    parser.set_defaults(save_model_interval=2048000)
    parser.set_defaults(gpu=0)
    parser.set_defaults(show_test_images=True)
    args = parser.parse_args()

    # Both envs use the DQN-style Atari wrappers; reward clipping is disabled for evaluation.
    env = wrap_dqn(gym.make(args.env_name))
    test_env = wrap_dqn(gym.make(args.env_name), reward_clipping=False)
    state_shape = env.observation_space.shape
    action_dim = env.action_space.n

    actor_critic = AtariCategoricalActorCritic(state_shape=state_shape, action_dim=action_dim)
    policy = PPO(state_shape=state_shape, action_dim=action_dim,
def setUpClass(cls):
    cls.agent = PPO(
        state_shape=cls.continuous_env.observation_space.shape,
        action_dim=cls.continuous_env.action_space.low.size,
        is_discrete=False,
        gpu=-1)
def setUpClass(cls):
    cls.agent = PPO(
        state_shape=cls.discrete_env.observation_space.shape,
        action_dim=cls.discrete_env.action_space.n,
        is_discrete=True,
        gpu=-1)
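# The two setUpClass hooks above reference cls.continuous_env and cls.discrete_env,
# which are presumably created elsewhere in the test suite. A minimal self-contained
# sketch of that missing context, assuming Pendulum-v0 and CartPole-v0 as the
# continuous/discrete environments (class and environment names are illustrative only):
import unittest

import gym
from tf2rl.algos.ppo import PPO


class TestPPOContinuous(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.continuous_env = gym.make("Pendulum-v0")
        cls.agent = PPO(
            state_shape=cls.continuous_env.observation_space.shape,
            action_dim=cls.continuous_env.action_space.low.size,
            is_discrete=False,
            gpu=-1)


class TestPPODiscrete(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.discrete_env = gym.make("CartPole-v0")
        cls.agent = PPO(
            state_shape=cls.discrete_env.observation_space.shape,
            action_dim=cls.discrete_env.action_space.n,
            is_discrete=True,
            gpu=-1)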
import os

import gym
import tensorflow as tf
import wandb

from tf2rl.algos.sac import SAC
from tf2rl.algos.td3 import TD3
from tf2rl.algos.ppo import PPO
from tf2rl.experiments.trainer import Trainer
from tf2rl.experiments.on_policy_trainer import OnPolicyTrainer
from tf2rl.envs.utils import is_discrete, get_act_dim


def run(parser):
    args = parser.parse_args()

    # GPU setup: hide all GPUs when gpu < 0, otherwise expose only the selected
    # device and cap its memory at 3 GB.
    if args.gpu < 0:
        tf.config.experimental.set_visible_devices([], 'GPU')
    else:
        physical_devices = tf.config.list_physical_devices('GPU')
        tf.config.set_visible_devices(physical_devices[args.gpu], 'GPU')
        tf.config.experimental.set_virtual_device_configuration(
            physical_devices[args.gpu],
            [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024 * 3)])

    # Map the numeric --env id onto an Assistive Gym task.
    # (The *PR2X environments are assumed to be registered elsewhere in the project.)
    if args.env == 200:
        envname = 'ScratchItchPR2X'
    elif args.env == 201:
        envname = 'DressingPR2X'
    elif args.env == 202:
        envname = 'BedBathingPR2X'
    else:
        raise ValueError(f'Unknown --env id: {args.env}')

    logdir = 'MFBox_Assistive'
    if args.SAC:
        wandb.init(config=vars(args), project="Assistive Gym", name=f'SAC on {envname}')
    elif args.PPO:
        wandb.init(config=vars(args), project="Assistive Gym", name=f'PPO on {envname}')
    elif args.TD3:
        wandb.init(config=vars(args), project="Assistive Gym", name=f'TD3 on {envname}')
    elif args.DEBUG:
        logdir = 'DEBUG_Assistive'
        wandb.init(config=vars(args), project="Assistive Gym", name=f'DEBUG on {envname}')
    else:
        print('PLEASE INDICATE THE ALGORITHM !!')
        return  # no algorithm flag was given, so there is nothing to run

    if not os.path.exists(logdir):
        os.makedirs(logdir)
    parser.set_defaults(logdir=logdir)
    args = parser.parse_args()

    env = gym.make(f'{envname}-v0')
    # test_env = Monitor(env, logdir, force=True)
    test_env = gym.make(f'{envname}-v0')

    if args.SAC:
        policy = SAC(state_shape=env.observation_space.shape,
                     action_dim=env.action_space.high.size,
                     gpu=args.gpu,
                     memory_capacity=args.memory_capacity,
                     max_action=env.action_space.high[0],
                     batch_size=args.batch_size,
                     n_warmup=args.n_warmup,
                     alpha=args.alpha,
                     auto_alpha=args.auto_alpha)
        trainer = Trainer(policy, env, args, test_env=test_env)
    elif args.PPO:
        policy = PPO(state_shape=env.observation_space.shape,
                     action_dim=get_act_dim(env.action_space),
                     is_discrete=is_discrete(env.action_space),
                     max_action=None if is_discrete(env.action_space) else env.action_space.high[0],
                     batch_size=args.batch_size,
                     actor_units=(64, 64),
                     critic_units=(64, 64),
                     n_epoch=10,
                     lr_actor=3e-4,
                     lr_critic=3e-4,
                     hidden_activation_actor="tanh",
                     hidden_activation_critic="tanh",
                     discount=0.99,
                     lam=0.95,
                     entropy_coef=0.,
                     horizon=args.horizon,
                     normalize_adv=args.normalize_adv,
                     enable_gae=args.enable_gae,
                     gpu=args.gpu)
        trainer = OnPolicyTrainer(policy, env, args, test_env=test_env)
    elif args.TD3:
        policy = TD3(state_shape=env.observation_space.shape,
                     action_dim=env.action_space.high.size,
                     gpu=args.gpu,
                     memory_capacity=args.memory_capacity,
                     max_action=env.action_space.high[0],
                     batch_size=args.batch_size,
                     n_warmup=args.n_warmup)
        trainer = Trainer(policy, env, args, test_env=test_env)
    elif args.DEBUG:
        # Debug mode: SAC with a tiny warm-up, frequent testing, and no separate test env.
        policy = SAC(state_shape=env.observation_space.shape,
                     action_dim=env.action_space.high.size,
                     gpu=args.gpu,
                     memory_capacity=args.memory_capacity,
                     max_action=env.action_space.high[0],
                     batch_size=args.batch_size,
                     n_warmup=100,
                     alpha=args.alpha,
                     auto_alpha=args.auto_alpha)
        parser.set_defaults(test_interval=200)
        args = parser.parse_args()
        trainer = Trainer(policy, env, args, test_env=None)

    trainer()
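# Usage (hedged): with the defaults set in main() above, a typical invocation would be
# something like `python <this_script>.py --PPO --env 200`; the actual script name is
# not part of these snippets.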