# Override the parser's default for test_interval before parsing.
parser.set_defaults(test_interval=500)

# Disabled overrides kept from the original script for reference.
#parser.set_defaults(horizon=1024)
#parser.set_defaults(batch_size=512)
# NOTE(review): gpu=-1 presumably means "run on CPU" -- confirm against the
# code that consumes args.gpu.
parser.set_defaults(gpu=-1)
parser.set_defaults(max_steps=100000000)
parser.set_defaults(n_warmup=0)
#parser.set_defaults(enable_gae=True)
# Parse only after every default above has been registered.
args = parser.parse_args()

# Two separately constructed environments with identical settings;
# `slow_step` is defined outside this fragment.
env = ArmEnvironment(static_goal=True,slow_step=slow_step)
test_env = ArmEnvironment(static_goal=True,slow_step=slow_step)

policy = PPO(
        state_shape=env.observation_space.shape,
        action_dim=get_act_dim(env.action_space),
        is_discrete=is_discrete(env.action_space),
        max_action=None if is_discrete(
            env.action_space) else env.action_space.high[0],
        batch_size=args.batch_size,
        actor_units=(64, 64),
        critic_units=(64, 64),
        n_epoch=10,
        lr_actor=3e-4,
        lr_critic=3e-4,
        hidden_activation_actor="tanh",
        hidden_activation_critic="tanh",
        discount=0.99,
        lam=0.95,
        entropy_coef=0.001,
        horizon=args.horizon,
示例#2
0
文件: MFBox.py 项目: Chenaah/RAMCO
# Maps the numeric ``args.env`` id to the gym environment name (without the
# ``-v0`` version suffix that is appended when the environment is created).
_ENV_NAMES = {
    200: 'ScratchItchPR2X',
    201: 'DressingPR2X',
    202: 'BedBathingPR2X',
}

# Algorithm flags looked up on ``args``, in priority order: when several are
# set, the first truthy one wins.
_ALGORITHMS = ('SAC', 'PPO', 'TD3', 'DEBUG')


def _resolve_env_name(env_id):
    """Return the environment name for *env_id* or raise ``ValueError``."""
    try:
        return _ENV_NAMES[env_id]
    except KeyError:
        raise ValueError(
            f'Unknown environment id {env_id!r}; '
            f'expected one of {sorted(_ENV_NAMES)}') from None


def _select_algorithm(args):
    """Return the name of the first algorithm flag that is set on *args*."""
    for name in _ALGORITHMS:
        if getattr(args, name):
            return name
    raise ValueError(
        'No algorithm selected: set one of args.'
        + ', args.'.join(_ALGORITHMS))


def _configure_gpu(gpu):
    """Hide all GPUs when *gpu* is negative, else expose only GPU *gpu*."""
    if gpu < 0:
        tf.config.experimental.set_visible_devices([], 'GPU')
        return

    physical_devices = tf.config.list_physical_devices('GPU')
    tf.config.set_visible_devices(physical_devices[gpu], 'GPU')
    # Cap the memory TensorFlow may allocate on the selected device
    # (memory_limit is expressed in MB in the TensorFlow API, so 3 * 1024).
    tf.config.experimental.set_virtual_device_configuration(
        physical_devices[gpu], [
            tf.config.experimental.VirtualDeviceConfiguration(
                memory_limit=1024 * 3)
        ])


def _make_sac(env, args, n_warmup):
    """Build a SAC policy for *env* with the given warm-up length."""
    return SAC(state_shape=env.observation_space.shape,
               action_dim=env.action_space.high.size,
               gpu=args.gpu,
               memory_capacity=args.memory_capacity,
               max_action=env.action_space.high[0],
               batch_size=args.batch_size,
               n_warmup=n_warmup,
               alpha=args.alpha,
               auto_alpha=args.auto_alpha)


def _make_ppo(env, args):
    """Build a PPO policy for *env* with fixed network hyper-parameters."""
    discrete = is_discrete(env.action_space)
    return PPO(state_shape=env.observation_space.shape,
               action_dim=get_act_dim(env.action_space),
               is_discrete=discrete,
               max_action=None if discrete else env.action_space.high[0],
               batch_size=args.batch_size,
               actor_units=(64, 64),
               critic_units=(64, 64),
               n_epoch=10,
               lr_actor=3e-4,
               lr_critic=3e-4,
               hidden_activation_actor="tanh",
               hidden_activation_critic="tanh",
               discount=0.99,
               lam=0.95,
               entropy_coef=0.,
               horizon=args.horizon,
               normalize_adv=args.normalize_adv,
               enable_gae=args.enable_gae,
               gpu=args.gpu)


def _make_td3(env, args):
    """Build a TD3 policy for *env*."""
    return TD3(state_shape=env.observation_space.shape,
               action_dim=env.action_space.high.size,
               gpu=args.gpu,
               memory_capacity=args.memory_capacity,
               max_action=env.action_space.high[0],
               batch_size=args.batch_size,
               n_warmup=args.n_warmup)


def run(parser):
    """Train the selected algorithm on the selected environment.

    The parsed arguments must provide ``env`` (200, 201 or 202), ``gpu``
    and the boolean flags ``SAC``, ``PPO``, ``TD3`` and ``DEBUG``; the
    first flag that is set, in that order, decides what is run.  ``DEBUG``
    runs SAC with ``n_warmup=100``, ``test_interval=200``, no test
    environment and its own log directory.

    Args:
        parser: Argument parser supporting ``parse_args`` and
            ``set_defaults``.  Its ``logdir`` default (and, in DEBUG mode,
            its ``test_interval`` default) is overridden here.

    Raises:
        ValueError: If ``args.env`` is not a known environment id or no
            algorithm flag is set.  Both are checked before any device,
            logging or environment setup is performed.
    """
    args = parser.parse_args()

    # Fail fast with a clear message instead of hitting an unbound local
    # after devices, wandb and the environments were already set up.
    envname = _resolve_env_name(args.env)
    algorithm = _select_algorithm(args)

    _configure_gpu(args.gpu)

    logdir = 'DEBUG_Assistive' if algorithm == 'DEBUG' else 'MFBox_Assistive'
    wandb.init(config=vars(args),
               project="Assistive Gym",
               name=f'{algorithm} on {envname}')

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(logdir, exist_ok=True)
    # Re-parse so that args.logdir reflects the new default (an explicit
    # command-line value still takes precedence).
    parser.set_defaults(logdir=logdir)
    args = parser.parse_args()

    env = gym.make(f'{envname}-v0')
    #test_env = Monitor(env,logdir,force=True)
    test_env = gym.make(f'{envname}-v0')

    if algorithm == 'SAC':
        policy = _make_sac(env, args, n_warmup=args.n_warmup)
        trainer = Trainer(policy, env, args, test_env=test_env)
    elif algorithm == 'PPO':
        policy = _make_ppo(env, args)
        trainer = OnPolicyTrainer(policy, env, args, test_env=test_env)
    elif algorithm == 'TD3':
        policy = _make_td3(env, args)
        trainer = Trainer(policy, env, args, test_env=test_env)
    else:  # DEBUG
        policy = _make_sac(env, args, n_warmup=100)
        # Evaluate more often while debugging; re-parse to pick up the
        # new default before handing args to the trainer.
        parser.set_defaults(test_interval=200)
        args = parser.parse_args()
        trainer = Trainer(policy, env, args, test_env=None)

    trainer()