# Example #1: train (or evaluate) PPO on the custom ArmEnvironment.
# NOTE(review): `parser`, `slow_step` and `testing` are defined outside this
# snippet -- confirm where they come from.
args = parser.parse_args()

# Separate instances for training and evaluation, both with static_goal=True.
env = ArmEnvironment(static_goal=True, slow_step=slow_step)
test_env = ArmEnvironment(static_goal=True, slow_step=slow_step)

policy = PPO(
        state_shape=env.observation_space.shape,
        action_dim=get_act_dim(env.action_space),
        is_discrete=is_discrete(env.action_space),
        # Continuous actions are scaled by the upper bound of the first
        # action dimension; discrete spaces need no scaling.
        max_action=None if is_discrete(
            env.action_space) else env.action_space.high[0],
        batch_size=args.batch_size,
        actor_units=(64, 64),
        critic_units=(64, 64),
        n_epoch=10,
        lr_actor=3e-4,
        lr_critic=3e-4,
        hidden_activation_actor="tanh",
        hidden_activation_critic="tanh",
        discount=0.99,
        lam=0.95,
        entropy_coef=0.001,
        horizon=args.horizon,
        normalize_adv=args.normalize_adv,
        enable_gae=args.enable_gae,
        gpu=args.gpu)
trainer = OnPolicyTrainer(policy, env, args, test_env=test_env)

if testing:
    # NOTE(review): tf2rl's Trainer.evaluate_policy takes a `total_steps`
    # argument -- confirm this zero-argument call matches the trainer used.
    trainer.evaluate_policy()
else:
    # The original `else:` had no body (snippet truncated); run training,
    # as the other examples in this file do.
    trainer()
示例#2
0
    # Example #2: PPO with a shared Atari actor-critic network.
    args = parser.parse_args()
    env_id = args.env_name

    # The training env uses wrap_dqn's defaults; the evaluation env passes
    # reward_clipping=False (presumably so reported test returns are the
    # unclipped game score -- confirm against wrap_dqn).
    env = wrap_dqn(gym.make(env_id))
    test_env = wrap_dqn(gym.make(env_id), reward_clipping=False)

    state_shape = env.observation_space.shape
    action_dim = env.action_space.n  # discrete action space (uses `.n`)

    actor_critic = AtariCategoricalActorCritic(
        state_shape=state_shape, action_dim=action_dim)

    ppo_kwargs = dict(
        state_shape=state_shape,
        action_dim=action_dim,
        is_discrete=True,
        actor_critic=actor_critic,
        batch_size=args.batch_size,
        n_epoch=3,
        lr_actor=2.5e-4,
        lr_critic=2.5e-4,
        discount=0.99,
        lam=0.95,
        horizon=args.horizon,
        normalize_adv=args.normalize_adv,
        enable_gae=args.enable_gae,
        gpu=args.gpu,
    )
    policy = PPO(**ppo_kwargs)
    trainer = OnPolicyTrainer(policy, env, args, test_env=test_env)

    # args.evaluate selects continuous evaluation instead of training.
    entry_point = (trainer.evaluate_policy_continuously if args.evaluate
                   else trainer)
    entry_point()
示例#3
0
 def setUpClass(cls):
     # Build one PPO agent for the discrete-action test env and store it on
     # the class; gpu=-1 presumably forces CPU execution.
     # NOTE(review): unittest invokes setUpClass as a classmethod; no
     # @classmethod decorator is visible in this snippet -- confirm it exists.
     env = cls.discrete_env
     obs_shape = env.observation_space.shape
     n_actions = env.action_space.n
     cls.agent = PPO(state_shape=obs_shape,
                     action_dim=n_actions,
                     is_discrete=True,
                     gpu=-1)
示例#4
0
 def setUpClass(cls):
     # Build one PPO agent for the continuous-action test env and store it on
     # the class; gpu=-1 presumably forces CPU execution.
     # NOTE(review): unittest invokes setUpClass as a classmethod; no
     # @classmethod decorator is visible in this snippet -- confirm it exists.
     env = cls.continuous_env
     obs_shape = env.observation_space.shape
     # Action dimensionality taken from the size of the lower-bound array
     # (presumably a Box space -- confirm).
     n_action_dims = env.action_space.low.size
     cls.agent = PPO(state_shape=obs_shape,
                     action_dim=n_action_dims,
                     is_discrete=False,
                     gpu=-1)
示例#5
0
文件: MFBox.py 项目: Chenaah/RAMCO
def run(parser):

    args = parser.parse_args()

    if args.gpu < 0:
        tf.config.experimental.set_visible_devices([], 'GPU')
    else:
        physical_devices = tf.config.list_physical_devices('GPU')
        tf.config.set_visible_devices(physical_devices[args.gpu], 'GPU')
        tf.config.experimental.set_virtual_device_configuration(
            physical_devices[args.gpu], [
                tf.config.experimental.VirtualDeviceConfiguration(
                    memory_limit=1024 * 3)
            ])

    if args.env == 200:
        envname = 'ScratchItchPR2X'
    elif args.env == 201:
        envname = 'DressingPR2X'
    elif args.env == 202:
        envname = 'BedBathingPR2X'

    logdir = f'MFBox_Assistive'
    if args.SAC:
        wandb.init(config=vars(args),
                   project="Assistive Gym",
                   name=f'SAC on {envname}')
    elif args.PPO:
        wandb.init(config=vars(args),
                   project="Assistive Gym",
                   name=f'PPO on {envname}')
    elif args.TD3:
        wandb.init(config=vars(args),
                   project="Assistive Gym",
                   name=f'TD3 on {envname}')
    elif args.DEBUG:
        logdir = f'DEBUG_Assistive'
        wandb.init(config=vars(args),
                   project="Assistive Gym",
                   name=f'DEBUG on {envname}')
    else:
        print('PLEASE INDICATE THE ALGORITHM !!')

    if not os.path.exists(logdir):
        os.makedirs(logdir)
    parser.set_defaults(logdir=logdir)
    args = parser.parse_args()

    env = gym.make(f'{envname}-v0')
    #test_env = Monitor(env,logdir,force=True)
    test_env = gym.make(f'{envname}-v0')

    if args.SAC:

        policy = SAC(state_shape=env.observation_space.shape,
                     action_dim=env.action_space.high.size,
                     gpu=args.gpu,
                     memory_capacity=args.memory_capacity,
                     max_action=env.action_space.high[0],
                     batch_size=args.batch_size,
                     n_warmup=args.n_warmup,
                     alpha=args.alpha,
                     auto_alpha=args.auto_alpha)
        trainer = Trainer(policy, env, args, test_env=test_env)

    elif args.PPO:
        policy = PPO(state_shape=env.observation_space.shape,
                     action_dim=get_act_dim(env.action_space),
                     is_discrete=is_discrete(env.action_space),
                     max_action=None if is_discrete(env.action_space) else
                     env.action_space.high[0],
                     batch_size=args.batch_size,
                     actor_units=(64, 64),
                     critic_units=(64, 64),
                     n_epoch=10,
                     lr_actor=3e-4,
                     lr_critic=3e-4,
                     hidden_activation_actor="tanh",
                     hidden_activation_critic="tanh",
                     discount=0.99,
                     lam=0.95,
                     entropy_coef=0.,
                     horizon=args.horizon,
                     normalize_adv=args.normalize_adv,
                     enable_gae=args.enable_gae,
                     gpu=args.gpu)
        trainer = OnPolicyTrainer(policy, env, args, test_env=test_env)

    elif args.TD3:
        policy = TD3(state_shape=env.observation_space.shape,
                     action_dim=env.action_space.high.size,
                     gpu=args.gpu,
                     memory_capacity=args.memory_capacity,
                     max_action=env.action_space.high[0],
                     batch_size=args.batch_size,
                     n_warmup=args.n_warmup)
        trainer = Trainer(policy, env, args, test_env=test_env)

    elif args.DEBUG:

        policy = SAC(state_shape=env.observation_space.shape,
                     action_dim=env.action_space.high.size,
                     gpu=args.gpu,
                     memory_capacity=args.memory_capacity,
                     max_action=env.action_space.high[0],
                     batch_size=args.batch_size,
                     n_warmup=100,
                     alpha=args.alpha,
                     auto_alpha=args.auto_alpha)
        parser.set_defaults(test_interval=200)
        args = parser.parse_args()

        trainer = Trainer(policy, env, args, test_env=None)

    trainer()