def make_mujoco_env(task, seed, training_num, test_num, obs_norm):
    """Wrapper function for Mujoco env.

    If EnvPool is installed, it will automatically switch to EnvPool's Mujoco env.

    :return: a tuple of (single env, training envs, test envs).
    """
    if envpool is not None:
        train_envs = env = envpool.make_gym(task, num_envs=training_num, seed=seed)
        test_envs = envpool.make_gym(task, num_envs=test_num, seed=seed)
    else:
        warnings.warn(
            "Recommend using envpool (pip install envpool) "
            "to run Mujoco environments more efficiently."
        )
        env = gym.make(task)
        train_envs = ShmemVectorEnv(
            [lambda: gym.make(task) for _ in range(training_num)]
        )
        test_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
        env.seed(seed)
        train_envs.seed(seed)
        test_envs.seed(seed)
    if obs_norm:
        # obs norm wrapper
        train_envs = VectorEnvNormObs(train_envs)
        test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)
        test_envs.set_obs_rms(train_envs.get_obs_rms())
    return env, train_envs, test_envs

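# A minimal usage sketch for make_mujoco_env. The task name, seed, and env counts
# below are illustrative assumptions, not values taken from the original code.
def example_make_mujoco_env_usage():
    env, train_envs, test_envs = make_mujoco_env(
        task="HalfCheetah-v3", seed=0, training_num=8, test_num=8, obs_norm=True
    )
    # The single env exposes observation/action spaces for building networks;
    # the vectorized train/test envs are what the Collectors actually step.
    print(env.observation_space, env.action_space)
    return env, train_envs, test_envs
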
def test_venv_wrapper_envpool_gym_reset_return_info():
    num_envs = 4
    env = VectorEnvNormObs(
        envpool.make_gym("Ant-v3", num_envs=num_envs, gym_reset_return_info=True)
    )
    obs, info = env.reset()
    assert obs.shape[0] == num_envs
    for _, v in info.items():
        if not isinstance(v, dict):
            assert v.shape[0] == num_envs

def test_venv_wrapper_envpool():
    raw = envpool.make_gym("Ant-v3", num_envs=4)
    train = VectorEnvNormObs(envpool.make_gym("Ant-v3", num_envs=4))
    test = VectorEnvNormObs(
        envpool.make_gym("Ant-v3", num_envs=4), update_obs_rms=False
    )
    test.set_obs_rms(train.get_obs_rms())
    actions = [
        np.array([raw.action_space.sample() for _ in range(4)]) for i in range(30)
    ]
    run_align_norm_obs(raw, train, test, actions)

def test_venv_norm_obs():
    sizes = np.array([5, 10, 15, 20])
    action = np.array([1, 1, 1, 1])
    total_step = 30
    action_list = [action] * total_step
    env_fns = [lambda i=x: MyTestEnv(size=i, array_state=True) for x in sizes]
    raw = DummyVectorEnv(env_fns)
    train_env = VectorEnvNormObs(DummyVectorEnv(env_fns))
    print(train_env.observation_space)
    test_env = VectorEnvNormObs(DummyVectorEnv(env_fns), update_obs_rms=False)
    test_env.set_obs_rms(train_env.get_obs_rms())
    run_align_norm_obs(raw, train_env, test_env, action_list)

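# The three tests above share a run_align_norm_obs helper that is not shown in this
# section. It presumably checks that observations coming out of the obs-norm wrappers
# line up with manually normalizing the raw observations using the shared running
# statistics. A hedged sketch of that normalization step (the `mean`/`var` attributes
# follow Tianshou's RunningMeanStd; the epsilon value is an assumption):
def normalize_obs_sketch(obs, obs_rms, eps=np.finfo(np.float32).eps.item()):
    # standardize with the running mean/variance gathered by the training wrapper
    return (obs - obs_rms.mean) / np.sqrt(obs_rms.var + eps)
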
def test_td3_bc():
    args = get_args()
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]  # float
    print("device:", args.device)
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print(
        "Action range:", np.min(env.action_space.low), np.max(env.action_space.high)
    )
    args.state_dim = args.state_shape[0]
    args.action_dim = args.action_shape[0]
    print("Max_action", args.max_action)

    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    if args.norm_obs:
        test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)

    # model
    # actor network
    net_a = Net(
        args.state_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
    )
    actor = Actor(
        net_a,
        action_shape=args.action_shape,
        max_action=args.max_action,
        device=args.device,
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

    # critic network
    net_c1 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    net_c2 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic1 = Critic(net_c1, device=args.device).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2 = Critic(net_c2, device=args.device).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

    policy = TD3BCPolicy(
        actor,
        actor_optim,
        critic1,
        critic1_optim,
        critic2,
        critic2_optim,
        tau=args.tau,
        gamma=args.gamma,
        exploration_noise=GaussianNoise(sigma=args.exploration_noise),
        policy_noise=args.policy_noise,
        update_actor_freq=args.update_actor_freq,
        noise_clip=args.noise_clip,
        alpha=args.alpha,
        estimation_step=args.n_step,
        action_space=env.action_space,
    )

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(
            torch.load(args.resume_path, map_location=args.device)
        )
        print("Loaded agent from: ", args.resume_path)

    # collector
    test_collector = Collector(policy, test_envs)

    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    args.algo_name = "td3_bc"
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    # logger
    if args.logger == "wandb":
        logger = WandbLogger(
            save_interval=1,
            name=log_name.replace(os.path.sep, "__"),
            run_id=args.resume_id,
            config=args,
            project=args.wandb_project,
        )
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    if args.logger == "tensorboard":
        logger = TensorboardLogger(writer)
    else:  # wandb
        logger.load(writer)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def watch():
        if args.resume_path is None:
            args.resume_path = os.path.join(log_path, "policy.pth")
        policy.load_state_dict(
            torch.load(args.resume_path, map_location=torch.device("cpu"))
        )
        policy.eval()
        collector = Collector(policy, env)
        collector.collect(n_episode=1, render=1 / 35)

    if not args.watch:
        replay_buffer = load_buffer_d4rl(args.expert_data_task)
        if args.norm_obs:
            replay_buffer, obs_rms = normalize_all_obs_in_replay_buffer(replay_buffer)
            test_envs.set_obs_rms(obs_rms)
        # trainer
        result = offline_trainer(
            policy,
            replay_buffer,
            test_collector,
            args.epoch,
            args.step_per_epoch,
            args.test_num,
            args.batch_size,
            save_best_fn=save_best_fn,
            logger=logger,
        )
        pprint.pprint(result)
    else:
        watch()

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num, render=args.render)
    print(
        f"Final reward: {result['rews'].mean()}, length: {result['lens'].mean()}"
    )

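# The offline script above relies on a normalize_all_obs_in_replay_buffer helper that
# returns the normalized buffer together with the statistics it used, so the test envs
# can be kept consistent via set_obs_rms. A minimal sketch of that idea, assuming plain
# obs/obs_next arrays and Tianshou's RunningMeanStd (an illustration under those
# assumptions, not the project's actual helper):
def normalize_dataset_obs_sketch(obs, obs_next):
    from tianshou.utils import RunningMeanStd

    eps = np.finfo(np.float32).eps.item()
    obs_rms = RunningMeanStd()
    obs_rms.update(obs)  # fit running mean/var on the offline dataset's observations

    def norm(x):
        return (x - obs_rms.mean) / np.sqrt(obs_rms.var + eps)

    # normalize obs and obs_next with the same statistics and return them for reuse
    return norm(obs), norm(obs_next), obs_rms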