def make_env():
    """Create the wrapped SweepToTop test-layout environment.

    The embodiment is read from ``FLAGS.embodiment`` (capitalized to form the
    environment ID) and the wrapping is driven by ``FLAGS.config``. The wrapper
    factory receives a CUDA device when one is available, otherwise the CPU.

    Returns:
        Whatever ``wrapper_from_config`` returns for the created environment.
    """
    # Register x-magical's environments with gym before creating one by ID.
    xmagical.register_envs()
    embodiment_name = FLAGS.embodiment.capitalize()
    base_env = gym.make(f"SweepToTop-{embodiment_name}-State-Allo-TestLayout-v0")
    if torch.cuda.is_available():
        device_type = "cuda"
    else:
        device_type = "cpu"
    return wrapper_from_config(FLAGS.config, base_env, torch.device(device_type))
Exemplo n.º 2
0
def make_env(
    env_name,
    seed,
    save_dir=None,
    add_episode_monitor=True,
    action_repeat=1,
    frame_stack=1,
):
  """Create an x-magical environment and apply the requested wrappers.

  Wrappers are applied innermost-first in this order: `EpisodeMonitor`,
  `ActionRepeat`, `RescaleAction` (to the range [-1, 1], always applied),
  `VideoRecorder`, `FrameStack`. Afterwards the environment, its action space
  and its observation space are each seeded with `seed`.

  Args:
    env_name: ID of an environment registered by x-magical.
    seed: RNG seed for the environment and its action/observation spaces.
    save_dir: If not None, wrap with `VideoRecorder` saving into this directory.
    add_episode_monitor: If True, wrap with `EpisodeMonitor`.
    action_repeat: Wrap with `ActionRepeat` when greater than 1.
    frame_stack: Wrap with `FrameStack` when greater than 1.

  Returns:
    The wrapped gym.Env object.

  Raises:
    ValueError: If `env_name` is not among x-magical's registered environments.
  """
  xmagical.register_envs()
  # Only environments registered by x-magical are accepted.
  if env_name not in xmagical.ALL_REGISTERED_ENVS:
    raise ValueError(f"{env_name} is not a valid environment name.")
  env = gym.make(env_name)

  if add_episode_monitor:
    env = wrappers.EpisodeMonitor(env)
  if action_repeat > 1:
    env = wrappers.ActionRepeat(env, action_repeat)
  # The action range is normalized to [-1, 1] unconditionally.
  env = RescaleAction(env, -1.0, 1.0)
  if save_dir is not None:
    env = wrappers.VideoRecorder(env, save_dir=save_dir)
  if frame_stack > 1:
    env = wrappers.FrameStack(env, frame_stack)

  # Seed the env first, then both of its spaces.
  env.seed(seed)
  for space in (env.action_space, env.observation_space):
    space.seed(seed)

  return env
Exemplo n.º 3
0
def main(_):
    """Drive the environment named by ``FLAGS.env_name`` from the keyboard.

    Registers the environments, creates the env, shows its first rendered
    frame and then hands a per-action step callback to
    ``KeyboardEnvInteractor.run_loop``.
    """
    register_envs()
    env = gym.make(FLAGS.env_name)
    viewer = KeyboardEnvInteractor(action_dim=env.action_space.shape[0])

    env.reset()
    viewer.imshow(env.render("rgb_array"))

    step_count = 0

    def step(action):
        # Returns the frame for the viewer to display, or None when the
        # episode is done and FLAGS.exit_on_done is set (presumably this tells
        # run_loop to stop — confirm against KeyboardEnvInteractor).
        nonlocal step_count
        frame, _, done, info = env.step(action)
        # Observations that are not 3-D (presumably not images) are replaced
        # by a rendered RGB frame.
        if frame.ndim != 3:
            frame = env.render("rgb_array")
        if done and FLAGS.exit_on_done:
            return None
        # Report the evaluation score every 100 steps, starting at step 0.
        if step_count % 100 == 0:
            print(f"Done, score {info['eval_score']:.2f}/1.00")
        step_count += 1
        return frame

    viewer.run_loop(step)
Exemplo n.º 4
0
throw an error (and subsequently fail all tests) if the dependency is not
installed.

References:
    [1]: https://github.com/HumanCompatibleAI/seals
    [2]: https://github.com/qxcv/magical
"""

import gym
import numpy as np
import pytest

import xmagical

# Number of rollouts per environment; deliberately small to bound test time.
N_ROLLOUTS = 2
# Populate xmagical.ALL_REGISTERED_ENVS by registering the environments.
xmagical.register_envs()


def make_env_fixture(skip_fn):
    def f(env_name: str):
        env = None
        try:
            env = gym.make(env_name)
            yield env
        except Exception as e:
            raise e
        finally:
            if env is not None:
                env.close()