예제 #1
0
def main(_):
    """Build a playable environment from flags and run it.

    Loads the loop hparams named by the flags, resolves the data / world
    model directories, constructs either a simulated or a real environment
    for the requested epoch, wraps it for interactive play and video
    monitoring, and then either performs a short scripted dry run or hands
    the environment to the interactive player.

    Args:
      _: unused positional argument.
    """
    # gym.logger.set_level(gym.logger.DEBUG)
    hparams = registry.hparams(FLAGS.loop_hparams_set)
    hparams.parse(FLAGS.loop_hparams)
    # Not important for experiments past 2018
    if "wm_policy_param_sharing" not in hparams.values().keys():
        hparams.add_hparam("wm_policy_param_sharing", False)

    paths = player_utils.infer_paths(
        output_dir=FLAGS.output_dir,
        world_model=FLAGS.wm_dir,
        policy=FLAGS.policy_dir,
        data=FLAGS.episodes_data_dir,
    )

    # The literal "last" is passed through unchanged; anything else must be
    # an integer epoch number.
    if FLAGS.epoch == "last":
        which_epoch = FLAGS.epoch
    else:
        which_epoch = int(FLAGS.epoch)

    if not FLAGS.simulated_env:
        batch_env = player_utils.setup_and_load_epoch(
            hparams,
            data_dir=paths["data"],
            which_epoch_data=which_epoch,
        )
        env = FlatBatchEnv(batch_env)
    else:
        env = player_utils.load_data_and_make_simulated_env(
            paths["data"],
            paths["world_model"],
            hparams,
            which_epoch_data=which_epoch,
        )

    env = PlayerEnvWrapper(env)  # pylint: disable=redefined-variable-type
    env = player_utils.wrap_with_monitor(env, FLAGS.video_dir)

    if not FLAGS.dry_run:
        play.play(env, zoom=FLAGS.zoom, fps=FLAGS.fps)
        return

    # Dry run: 5 rounds of 50 steps cycling through actions 0, 1, 2, each
    # round finished with the wrapper's explicit reset action.
    for _round in range(5):
        env.reset()
        for step_idx in range(50):
            env.step(step_idx % 3)
        env.step(PlayerEnvWrapper.RESET_ACTION)  # reset
예제 #2
0
class SimulatedGymEnv(gym.Env):
    """Gym environment, running with world model.

    Allows passing custom initial frames.

    Examples:
      Setup simulated env from some point of real rollout.
        >>> sim_env = SimulatedGymEnv(setable_initial_frames=True, **kwargs)
        >>> real_env = FlatBatchEnv(T2TGymEnv(...))
        >>> while ...:
        >>>   ob, _, _, _ = real_env.step(action)
        >>>   sim_env.add_to_initial_stack(ob)
        >>> sim_env.reset()
        >>> # Continue sim_env rollout.
    """

    def __init__(self,
                 real_env,
                 world_model_dir,
                 hparams,
                 random_starts,
                 setable_initial_frames=False):
        """Init.

        Args:
          real_env: gym environment.
          world_model_dir: path to world model checkpoint directory.
          hparams: hparams for rlmb pipeline.
          random_starts: if restart world model from random frames, or only
            from initial ones (from beginning of episodes). Valid only when
            `setable_initial_frames` set to False.
          setable_initial_frames: if True, initial_frames for world model
            should be set by `add_to_initial_stack`.
        """

        self._setable_initial_frames = setable_initial_frames

        if self._setable_initial_frames:
            # Frame stack buffer laid out as
            # (batch=1, frame_stack_size) + observation shape, zero-filled
            # until frames are pushed via `add_to_initial_stack`.
            real_obs_shape = real_env.observation_space.shape
            shape = (1, hparams.frame_stack_size) + real_obs_shape
            self._initial_frames = np.zeros(shape=shape, dtype=np.uint8)

            def initial_frame_chooser(batch_size):
                # Reads the attribute at call time, so it always sees the
                # most recent stack set by `add_to_initial_stack`.
                assert batch_size == 1
                return self._initial_frames

        else:
            initial_frame_chooser = rl_utils.make_initial_frame_chooser(
                real_env,
                hparams.frame_stack_size,
                simulation_random_starts=random_starts,
                simulation_flip_first_random_for_beginning=False)
        env_fn = make_simulated_env_fn_from_hparams(
            real_env,
            hparams,
            batch_size=1,
            initial_frame_chooser=initial_frame_chooser,
            model_dir=world_model_dir,
        )

        env = env_fn(in_graph=False)
        self.env = FlatBatchEnv(env)

        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space

    def reset(self):
        """Resets the wrapped simulated env and returns its result."""
        return self.env.reset()

    def step(self, action):
        """Forwards `action` to the wrapped simulated env and returns its result."""
        return self.env.step(action)

    def add_to_initial_stack(self, frame):
        """Pushes a new frame onto the initial frame stack, dropping the oldest.

        The stack is shifted by one position along the frame axis and `frame`
        is written into the final slot.

        Args:
          frame: array whose shape equals the per-frame shape of the stack
            (i.e. `self._initial_frames.shape[2:]`).

        Raises:
          ValueError: if this instance was created with
            `setable_initial_frames=False`.
          AssertionError: if `frame` has an unexpected shape.
        """
        if not self._setable_initial_frames:
            raise ValueError(
                "This instance does not allow to manually set initial frame stack."
            )
        # Report the same per-frame shape that the check below compares
        # against, so a failure shows actual vs. expected.
        expected_shape = self._initial_frames.shape[2:]
        assert_msg = "{}, {}".format(frame.shape, expected_shape)
        assert frame.shape == expected_shape, assert_msg
        # np.roll returns a new array: slot 0 (oldest) wraps to the end and
        # is then overwritten by the incoming frame.
        initial_frames = np.roll(self._initial_frames, shift=-1, axis=1)
        initial_frames[0, -1, ...] = frame
        self._initial_frames = initial_frames