Example #1
 def get_argument(parser=None):
     parser = Trainer.get_argument(parser)
     parser.add_argument(
         '--expert-path-dir',
         default=None,
         help='Path to directory that contains expert trajectories')
     return parser
Example #2
 def get_argument(parser=None):
     parser = Trainer.get_argument(parser)
     parser.add_argument('--gpu', type=int, default=0, help='GPU id')
     parser.add_argument("--max-iter", type=int, default=100)
     parser.add_argument("--horizon", type=int, default=20)
     parser.add_argument("--n-sample", type=int, default=1000)
     parser.add_argument("--n-random-rollout", type=int, default=1000)
     parser.add_argument("--batch-size", type=int, default=512)
     return parser
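
The two helpers above only build the argument parser; the parsed namespace is what actually drives training, as the later examples show. A minimal end-to-end sketch of that flow, assuming the same tf2rl imports and the Pendulum-v0 / DDPG setup used in the test examples below (Examples #3-#5):

import gym

from tf2rl.algos.ddpg import DDPG
from tf2rl.experiments.trainer import Trainer

if __name__ == '__main__':
    # Build the parser with the Trainer options, then parse the command line.
    parser = Trainer.get_argument()
    args = parser.parse_args()

    env = gym.make("Pendulum-v0")
    test_env = gym.make("Pendulum-v0")
    # Hyperparameters are hard-coded here, as in the test examples below.
    policy = DDPG(state_shape=env.observation_space.shape,
                  action_dim=env.action_space.high.size,
                  gpu=-1,
                  max_action=env.action_space.high[0])
    # The parsed namespace is handed straight to Trainer.
    trainer = Trainer(policy, env, args, test_env=test_env)
    trainer()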
Example #3
 def test_empty_args(self):
     """
     Test empty args {}
     """
     env = gym.make("Pendulum-v0")
     test_env = gym.make("Pendulum-v0")
     policy = DDPG(state_shape=env.observation_space.shape,
                   action_dim=env.action_space.high.size,
                   gpu=-1,
                   memory_capacity=1000,
                   max_action=env.action_space.high[0],
                   batch_size=32,
                   n_warmup=10)
     Trainer(policy, env, {}, test_env=test_env)
Example #4
 def test_invalid_args(self):
     """
     Test with invalid args
     """
     env = gym.make("Pendulum-v0")
     test_env = gym.make("Pendulum-v0")
     policy = DDPG(state_shape=env.observation_space.shape,
                   action_dim=env.action_space.high.size,
                   gpu=-1,
                   memory_capacity=1000,
                   max_action=env.action_space.high[0],
                   batch_size=32,
                   n_warmup=10)
     with self.assertRaises(ValueError):
         Trainer(policy, env, {"NOT_EXISTING_OPTIONS": 1}, test_env=test_env)
Example #5
 def test_with_args(self):
     """
     Test with args
     """
     max_steps = 400
     env = gym.make("Pendulum-v0")
     test_env = gym.make("Pendulum-v0")
     policy = DDPG(state_shape=env.observation_space.shape,
                   action_dim=env.action_space.high.size,
                   gpu=-1,
                   memory_capacity=1000,
                   max_action=env.action_space.high[0],
                   batch_size=32,
                   n_warmup=10)
     trainer = Trainer(policy, env, {"max_steps": max_steps}, test_env=test_env)
     self.assertEqual(trainer._max_steps, max_steps)
Example #6
    def get_argument(parser=None):
        """
        Create or update an argument parser for a command-line program.

        Args:
            parser (argparse.ArgumentParser, optional): argument parser to extend

        Returns:
            argparse.ArgumentParser: argument parser
        """
        parser = Trainer.get_argument(parser)
        parser.add_argument(
            '--expert-path-dir',
            default=None,
            help='Path to directory that contains expert trajectories')
        return parser
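
The docstring spells out the create-or-update contract: with parser=None, Trainer.get_argument builds a fresh parser; otherwise the caller's parser is extended and returned. A minimal sketch of both call patterns, assuming the method above lives on a hypothetical ExpertTrainer class (the class name is not from the snippet):

import argparse

from tf2rl.experiments.trainer import Trainer


class ExpertTrainer:  # hypothetical container for the get_argument shown above
    @staticmethod
    def get_argument(parser=None):
        parser = Trainer.get_argument(parser)
        parser.add_argument(
            '--expert-path-dir',
            default=None,
            help='Path to directory that contains expert trajectories')
        return parser


# 1) No parser supplied: Trainer.get_argument creates one, then it is extended.
parser = ExpertTrainer.get_argument()

# 2) An existing parser is supplied and updated in place, which is what lets
#    Example #7 chain several get_argument calls onto a single parser.
base = argparse.ArgumentParser(conflict_handler='resolve')
base = ExpertTrainer.get_argument(parser=base)

args = parser.parse_args(['--expert-path-dir', './expert_trajs'])
print(args.expert_path_dir)  # -> ./expert_trajs (illustrative path)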
Example #7
def main():
    parser = Trainer.get_argument()
    parser = OnPolicyTrainer.get_argument(parser)
    parser = SAC.get_argument(parser)
    parser = PPO.get_argument(parser)
    parser = TD3.get_argument(parser)

    parser.add_argument('--SAC', action='store_true')
    parser.add_argument('--PPO', action='store_true')
    parser.add_argument('--TD3', action='store_true')
    parser.add_argument('--DEBUG', action='store_true')
    parser.add_argument('--env', type=int, default=0)

    parser.set_defaults(batch_size=32)  #100
    parser.set_defaults(n_warmup=10000)  #10000
    parser.set_defaults(max_steps=2e6)
    parser.set_defaults(gpu=0)
    parser.set_defaults(test_interval=200 * 100)
    parser.set_defaults(test_episodes=3)

    args = parser.parse_args()
    print(vars(args))
    run(parser)
Example #8
def main():
    dm_envs = {
        'finger': ['finger', 'spin', 2],
        'cartpole': ['cartpole', 'swingup', 8],
        'reacher': ['reacher', 'easy', 4],
        'cheetah': ['cheetah', 'run', 4],
        'walker': ['walker', 'walk', 2],
        'ball': ['ball_in_cup', 'catch', 4],
        'humanoid': ['humanoid', 'stand', 4],
        'bring_ball': ['manipulator', 'bring_ball', 4],
        'bring_peg': ['manipulator', 'bring_peg', 4],
        'insert_ball': ['manipulator', 'insert_ball', 4],
        'insert_peg': ['manipulator', 'insert_peg', 4]}

    parser = Trainer.get_argument()
    parser = CURL.get_argument(parser)
    parser.add_argument('--env-name', type=str, default="cartpole", choices=dm_envs.keys())
    parser.add_argument('--seed', type=int, default=1)
    parser.set_defaults(save_summary_interval=50)
    parser.set_defaults(memory_capacity=int(1e5))
    args = parser.parse_args()

    domain_name, task_name, action_repeat = dm_envs[args.env_name]
    original_obs_shape = (100, 100, 9)
    input_obs_shape = (84, 84, 9)

    def make_env():
        return DMCWrapper(
            dmc2gym.make(
                domain_name=domain_name,
                task_name=task_name,
                seed=args.seed,
                visualize_reward=False,
                from_pixels=True,
                height=100,
                width=100,
                frame_skip=action_repeat,
                channels_first=False),
            obs_shape=original_obs_shape,
            k=3,
            channel_first=False)

    env = make_env()
    test_env = make_env()

    # see Table 3 of CURL paper
    lr_sac = lr_encoder = 2e-4 if args.env_name == "cheetah" else 1e-3

    policy = CURL(
        obs_shape=input_obs_shape,
        action_dim=env.action_space.high.size,
        gpu=args.gpu,
        memory_capacity=args.memory_capacity,
        n_warmup=int(1e3),
        max_action=env.action_space.high[0],
        batch_size=512,
        actor_units=(1024, 1024),
        critic_units=(1024, 1024),
        lr_sac=lr_sac,
        lr_encoder=lr_encoder,
        lr_alpha=1e-4,
        tau_critic=0.01,
        init_temperature=0.1,
        auto_alpha=True,
        stop_q_grad=args.stop_q_grad)

    trainer = Trainer(policy, env, args, test_env=test_env)
    if args.evaluate:
        trainer.evaluate_policy_continuously()
    else:
        trainer()
Example #9
from tf2rl.algos.dqn import DQN
from tf2rl.experiments.trainer import Trainer
from tf2rl.envs.utils import make

if __name__ == '__main__':
    parser = Trainer.get_argument()
    parser = DQN.get_argument(parser)
    parser.set_defaults(test_interval=2000)
    parser.set_defaults(max_steps=100000)
    parser.set_defaults(gpu=-1)
    parser.set_defaults(n_warmup=500)
    parser.set_defaults(batch_size=32)
    parser.set_defaults(memory_capacity=int(1e4))
    parser.add_argument('--env-name', type=str, default="CartPole-v0")
    args = parser.parse_args()

    env = make(args.env_name)
    test_env = make(args.env_name)
    policy = DQN(enable_double_dqn=args.enable_double_dqn,
                 enable_dueling_dqn=args.enable_dueling_dqn,
                 enable_noisy_dqn=args.enable_noisy_dqn,
                 state_shape=env.observation_space.shape,
                 action_dim=env.action_space.n,
                 target_replace_interval=300,
                 discount=0.99,
                 gpu=args.gpu,
                 memory_capacity=args.memory_capacity,
                 batch_size=args.batch_size,
                 n_warmup=args.n_warmup)
    trainer = Trainer(policy, env, args, test_env=test_env)
    if args.evaluate:
        trainer.evaluate_policy_continuously()
    else:
        trainer()
Example #10
import gym

from tf2rl.algos.dqn import DQN
from tf2rl.envs.atari_wrapper import wrap_dqn
from tf2rl.experiments.trainer import Trainer
from tf2rl.networks.atari_model import AtariQFunc as QFunc

if __name__ == '__main__':
    parser = Trainer.get_argument()
    parser = DQN.get_argument(parser)
    parser.add_argument('--env-name',
                        type=str,
                        default="SpaceInvadersNoFrameskip-v4")
    parser.set_defaults(episode_max_steps=108000)
    parser.set_defaults(test_interval=10000)
    parser.set_defaults(max_steps=int(1e9))
    parser.set_defaults(save_model_interval=500000)
    parser.set_defaults(gpu=0)
    parser.set_defaults(show_test_images=True)
    parser.set_defaults(memory_capacity=int(1e6))
    args = parser.parse_args()

    env = wrap_dqn(gym.make(args.env_name))
    test_env = wrap_dqn(gym.make(args.env_name), reward_clipping=False)
    # Following parameters are equivalent to DeepMind DQN paper
    # https://www.nature.com/articles/nature14236
    policy = DQN(
        enable_double_dqn=args.enable_double_dqn,
        enable_dueling_dqn=args.enable_dueling_dqn,
        enable_noisy_dqn=args.enable_noisy_dqn,
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.n,
        discount=0.99,
        gpu=args.gpu,
        memory_capacity=args.memory_capacity,
        batch_size=args.batch_size,
        n_warmup=args.n_warmup,
        # The original snippet is truncated here; the remaining arguments are
        # reconstructed from the analogous DQN setup in Example #9, and q_func
        # is assumed to take the otherwise-unused QFunc import.
        q_func=QFunc)
    trainer = Trainer(policy, env, args, test_env=test_env)
    trainer()
Example #11
def main():
    dm_envs = {
        'finger': ['finger', 'spin', 2],
        'cartpole': ['cartpole', 'balance', 4],
        'reacher': ['reacher', 'easy', 4],
        'cheetah': ['cheetah', 'run', 4],
        'walker': ['walker', 'walk', 2],
        'ball': ['ball_in_cup', 'catch', 4],
        'humanoid': ['humanoid', 'stand', 4],
        'bring_ball': ['manipulator', 'bring_ball', 4],
        'bring_peg': ['manipulator', 'bring_peg', 4],
        'insert_ball': ['manipulator', 'insert_ball', 4],
        'insert_peg': ['manipulator', 'insert_peg', 4]
    }

    parser = Trainer.get_argument()
    parser = CURLSAC.get_argument(parser)
    parser.add_argument('--env-name',
                        type=str,
                        default="cartpole",
                        choices=dm_envs.keys())
    parser.add_argument('--seed', type=int, default=1)
    parser.set_defaults(batch_size=256)
    parser.set_defaults(n_warmup=10000)
    parser.set_defaults(max_steps=3e6)
    parser.set_defaults(save_summary_interval=100)
    args = parser.parse_args()

    domain_name, task_name, action_repeat = dm_envs[args.env_name]
    original_obs_shape = (100, 100, 9)
    input_obs_shape = (84, 84, 9)

    def make_env():
        return DMCWrapper(dmc2gym.make(domain_name=domain_name,
                                       task_name=task_name,
                                       seed=args.seed,
                                       visualize_reward=False,
                                       from_pixels=True,
                                       height=100,
                                       width=100,
                                       frame_skip=action_repeat,
                                       channels_first=False),
                          obs_shape=original_obs_shape,
                          k=3,
                          channel_first=False)

    env = make_env()
    test_env = make_env()

    policy = CURLSAC(obs_shape=input_obs_shape,
                     action_dim=env.action_space.high.size,
                     gpu=args.gpu,
                     memory_capacity=int(1e5),
                     n_warmup=int(1e3),
                     max_action=env.action_space.high[0],
                     batch_size=10,
                     alpha=args.alpha,
                     auto_alpha=args.auto_alpha)

    trainer = Trainer(policy, env, args, test_env=test_env)
    if args.evaluate:
        trainer.evaluate_policy_continuously()
    else:
        trainer()
Example #12
def run(parser):

    args = parser.parse_args()

    if args.gpu < 0:
        tf.config.experimental.set_visible_devices([], 'GPU')
    else:
        physical_devices = tf.config.list_physical_devices('GPU')
        tf.config.set_visible_devices(physical_devices[args.gpu], 'GPU')
        tf.config.experimental.set_virtual_device_configuration(
            physical_devices[args.gpu], [
                tf.config.experimental.VirtualDeviceConfiguration(
                    memory_limit=1024 * 3)
            ])

    if args.env == 200:
        envname = 'ScratchItchPR2X'
    elif args.env == 201:
        envname = 'DressingPR2X'
    elif args.env == 202:
        envname = 'BedBathingPR2X'
    else:
        raise ValueError(f'Unknown --env id: {args.env}')

    logdir = 'MFBox_Assistive'
    if args.SAC:
        wandb.init(config=vars(args),
                   project="Assistive Gym",
                   name=f'SAC on {envname}')
    elif args.PPO:
        wandb.init(config=vars(args),
                   project="Assistive Gym",
                   name=f'PPO on {envname}')
    elif args.TD3:
        wandb.init(config=vars(args),
                   project="Assistive Gym",
                   name=f'TD3 on {envname}')
    elif args.DEBUG:
        logdir = 'DEBUG_Assistive'
        wandb.init(config=vars(args),
                   project="Assistive Gym",
                   name=f'DEBUG on {envname}')
    else:
        print('PLEASE INDICATE THE ALGORITHM !!')
        return  # without an algorithm flag there is nothing to train

    if not os.path.exists(logdir):
        os.makedirs(logdir)
    parser.set_defaults(logdir=logdir)
    args = parser.parse_args()

    env = gym.make(f'{envname}-v0')
    #test_env = Monitor(env,logdir,force=True)
    test_env = gym.make(f'{envname}-v0')

    if args.SAC:

        policy = SAC(state_shape=env.observation_space.shape,
                     action_dim=env.action_space.high.size,
                     gpu=args.gpu,
                     memory_capacity=args.memory_capacity,
                     max_action=env.action_space.high[0],
                     batch_size=args.batch_size,
                     n_warmup=args.n_warmup,
                     alpha=args.alpha,
                     auto_alpha=args.auto_alpha)
        trainer = Trainer(policy, env, args, test_env=test_env)

    elif args.PPO:
        policy = PPO(state_shape=env.observation_space.shape,
                     action_dim=get_act_dim(env.action_space),
                     is_discrete=is_discrete(env.action_space),
                     max_action=None if is_discrete(env.action_space) else
                     env.action_space.high[0],
                     batch_size=args.batch_size,
                     actor_units=(64, 64),
                     critic_units=(64, 64),
                     n_epoch=10,
                     lr_actor=3e-4,
                     lr_critic=3e-4,
                     hidden_activation_actor="tanh",
                     hidden_activation_critic="tanh",
                     discount=0.99,
                     lam=0.95,
                     entropy_coef=0.,
                     horizon=args.horizon,
                     normalize_adv=args.normalize_adv,
                     enable_gae=args.enable_gae,
                     gpu=args.gpu)
        trainer = OnPolicyTrainer(policy, env, args, test_env=test_env)

    elif args.TD3:
        policy = TD3(state_shape=env.observation_space.shape,
                     action_dim=env.action_space.high.size,
                     gpu=args.gpu,
                     memory_capacity=args.memory_capacity,
                     max_action=env.action_space.high[0],
                     batch_size=args.batch_size,
                     n_warmup=args.n_warmup)
        trainer = Trainer(policy, env, args, test_env=test_env)

    elif args.DEBUG:

        policy = SAC(state_shape=env.observation_space.shape,
                     action_dim=env.action_space.high.size,
                     gpu=args.gpu,
                     memory_capacity=args.memory_capacity,
                     max_action=env.action_space.high[0],
                     batch_size=args.batch_size,
                     n_warmup=100,
                     alpha=args.alpha,
                     auto_alpha=args.auto_alpha)
        parser.set_defaults(test_interval=200)
        args = parser.parse_args()

        trainer = Trainer(policy, env, args, test_env=None)

    trainer()