import gym
import pybullet_envs
from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack
from stable_baselines3.common.vec_env import DummyVecEnv
import common.utils as utils

# Point the project utilities at experiment folder "exp5"
# (presumably where body templates / data live -- confirm in common.utils).
utils.folder = "exp5"

# Build the single environment factory for robot body 400, then wrap it in a
# one-environment vectorized env with the last 4 frames stacked.
env_factory = utils.make_env(
    template=utils.template(400),
    robot_body=400,
    wrapper=None,
)
venv = VecFrameStack(DummyVecEnv([env_factory]), 4)

# Reset once and report the shape of the stacked observation.
obs = venv.reset()
print(obs.shape)
# Example #2
0
        gym_interface.make_env(
            rank=i,
            seed=common.seed,
            wrappers=default_wrapper,
            render=args.render,
            robot_body=args.train_bodies[i % len(args.train_bodies)])
        for i in range(args.num_venvs)
    ])

    # Optionally wrap the vectorized env in VecNormalize, forwarding the
    # discount factor from the hyperparameters as its only keyword argument.
    normalize_kwargs = {}
    if args.vec_normalize:
        normalize_kwargs["gamma"] = hyperparams["gamma"]
        venv = VecNormalize(venv, **normalize_kwargs)

    # Frame stacking is applied only when more than one frame is requested.
    if args.stack_frames > 1:
        venv = VecFrameStack(venv, args.stack_frames)

    # Drop these entries from hyperparams when present.
    # NOTE(review): presumably they are consumed elsewhere and must not be
    # forwarded to the model constructor -- confirm against the caller.
    keys_remove = ["normalize", "n_envs", "n_timesteps", "policy"]
    for key in keys_remove:
        if key in hyperparams:
            del hyperparams[key]

    # Progress message before the evaluation environments are built.
    print("Making eval environments...")
    all_callbacks = []
    for test_body in args.test_bodies:
        body_info = 0
        eval_venv = DummyVecEnv([
            gym_interface.make_env(rank=0,
                                   seed=common.seed + 1,
                                   wrappers=default_wrapper,
                                   render=False,
from utils.wrappers import DepthWrapper

# Output locations, relative to the current working directory.
# NOTE(review): presumably log_dir holds training logs, save_path the final
# model and best_save_path the best evaluated model -- confirm where used.
log_dir = "./data/reach_depth_sb_log"
save_path = "./data/reach_depth_sb"
best_save_path = "./data/reach_depth_sb_best"

# Create the log directory up front; exist_ok=True makes reruns harmless.
os.makedirs(log_dir, exist_ok=True)


def env_fn():
    """Create one wrapped "PepperReachDepth-v0" environment.

    The base environment is made with ``gui=False`` and ``dense=True``,
    limited to 100 steps per episode by ``TimeLimit``, and finally wrapped
    in ``DepthWrapper``.

    Returns:
        The ``DepthWrapper``-wrapped environment instance.
    """
    base_env = gym.make("PepperReachDepth-v0", gui=False, dense=True)
    time_limited_env = TimeLimit(base_env, max_episode_steps=100)
    return DepthWrapper(time_limited_env)


# Training and evaluation use separately constructed environments that share
# the same frame-stacking settings: 8 frames, stacked channels-first.
_frame_stack_kwargs = {"n_stack": 8, "channels_order": "first"}

env = VecFrameStack(DummyVecEnv([env_fn]), **_frame_stack_kwargs)

eval_env = VecFrameStack(DummyVecEnv([env_fn]), **_frame_stack_kwargs)

# Keyword arguments collected for the policy: ReLU activations, a net_arch of
# three 64-wide entries, image normalization turned off, and StackCNN as the
# features extractor together with the arguments passed to it.
policy_kwargs = {
    "activation_fn": th.nn.ReLU,
    "net_arch": [64, 64, 64],
    "normalize_images": False,
    "features_extractor_class": StackCNN,
    "features_extractor_kwargs": {
        "features_dim": 16,
        "linear_dim": 16,
        "n_channels": 1,
    },
}