def create_env_fun(game_name, sticky_actions=True):
  del game_name, sticky_actions
  batch_env = batch_env_fn(in_graph=False)
  env = FlatBatchEnv(batch_env)
  env = TimeLimit(env, max_episode_steps=time_limit)
  env = ResizeObservation(env)  # pylint: disable=redefined-variable-type
  env = GameOverOnDone(env)
  return env
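# Usage sketch for `create_env_fun` (assumptions: `batch_env_fn` and
# `time_limit` are captured from an enclosing factory, e.g. a hypothetical
# `get_create_env_fun(batch_env_fn, time_limit)`; Dopamine-style callers pass
# a game name, which this function deliberately ignores):
#
#   create_env = get_create_env_fun(batch_env_fn, time_limit=200)
#   env = create_env("pong", sticky_actions=False)
#   observation = env.reset()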
def __init__(self, real_env, world_model_dir, hparams, random_starts,
             setable_initial_frames=False):
  """Init.

  Args:
    real_env: gym environment.
    world_model_dir: path to world model checkpoint directory.
    hparams: hparams for rlmb pipeline.
    random_starts: whether to restart the world model from random frames,
      or only from initial ones (from the beginnings of episodes). Valid
      only when `setable_initial_frames` is set to False.
    setable_initial_frames: if True, initial frames for the world model
      should be set by `add_to_initial_stack`.
  """
  self._setable_initial_frames = setable_initial_frames

  if self._setable_initial_frames:
    real_obs_shape = real_env.observation_space.shape
    shape = (1, hparams.frame_stack_size) + real_obs_shape
    self._initial_frames = np.zeros(shape=shape, dtype=np.uint8)

    def initial_frame_chooser(batch_size):
      assert batch_size == 1
      return self._initial_frames
  else:
    initial_frame_chooser = rl_utils.make_initial_frame_chooser(
        real_env, hparams.frame_stack_size,
        simulation_random_starts=random_starts,
        simulation_flip_first_random_for_beginning=False)

  env_fn = make_simulated_env_fn_from_hparams(
      real_env, hparams, batch_size=1,
      initial_frame_chooser=initial_frame_chooser,
      model_dir=world_model_dir,
  )

  env = env_fn(in_graph=False)
  self.env = FlatBatchEnv(env)

  self.observation_space = self.env.observation_space
  self.action_space = self.env.action_space
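# Usage sketch for the setable-initial-frames path (assumptions: this
# `__init__` belongs to a simulated-env wrapper class, called `SimulatedEnv`
# here for illustration, and `add_to_initial_stack` fills
# `self._initial_frames` one frame at a time):
#
#   sim_env = SimulatedEnv(real_env, world_model_dir, hparams,
#                          random_starts=False, setable_initial_frames=True)
#   for frame in warmup_frames:  # hparams.frame_stack_size real frames
#     sim_env.add_to_initial_stack(frame)
#   observation = sim_env.env.reset()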
def make_simulated_gym_env(real_env, world_model_dir, hparams, random_starts):
  """Gym environment with world model."""
  initial_frame_chooser = rl_utils.make_initial_frame_chooser(
      real_env, hparams.frame_stack_size,
      simulation_random_starts=random_starts,
      simulation_flip_first_random_for_beginning=False)
  env_fn = make_simulated_env_fn_from_hparams(
      real_env, hparams, batch_size=1,
      initial_frame_chooser=initial_frame_chooser,
      model_dir=world_model_dir)
  env = env_fn(in_graph=False)
  flat_env = FlatBatchEnv(env)
  return flat_env
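# Minimal usage sketch for `make_simulated_gym_env` (placeholder path; the
# real `hparams` come from the rlmb hparams registry):
#
#   env = make_simulated_gym_env(
#       real_env, world_model_dir="/tmp/world_model",
#       hparams=hparams, random_starts=True)
#   observation = env.reset()
#   observation, reward, done, info = env.step(env.action_space.sample())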
def main(_):
  # gym.logger.set_level(gym.logger.DEBUG)
  hparams = registry.hparams(FLAGS.loop_hparams_set)
  hparams.parse(FLAGS.loop_hparams)
  # Not important for experiments past 2018.
  if "wm_policy_param_sharing" not in hparams.values().keys():
    hparams.add_hparam("wm_policy_param_sharing", False)
  directories = player_utils.infer_paths(
      output_dir=FLAGS.output_dir, world_model=FLAGS.wm_dir,
      policy=FLAGS.policy_dir, data=FLAGS.episodes_data_dir)

  epoch = FLAGS.epoch if FLAGS.epoch == "last" else int(FLAGS.epoch)

  if FLAGS.simulated_env:
    env = player_utils.load_data_and_make_simulated_env(
        directories["data"], directories["world_model"],
        hparams, which_epoch_data=epoch)
  else:
    env = player_utils.setup_and_load_epoch(
        hparams, data_dir=directories["data"],
        which_epoch_data=epoch)
    env = FlatBatchEnv(env)

  env = PlayerEnvWrapper(env)  # pylint: disable=redefined-variable-type

  env = player_utils.wrap_with_monitor(env, FLAGS.video_dir)

  if FLAGS.dry_run:
    for _ in range(5):
      env.reset()
      for i in range(50):
        env.step(i % 3)
      env.step(PlayerEnvWrapper.RESET_ACTION)  # reset
    return

  play.play(env, zoom=FLAGS.zoom, fps=FLAGS.fps)
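# Example invocation (assumptions: `main` is the entry point of the player
# script; flag names are inferred from the FLAGS references above, and the
# paths are placeholders):
#
#   python player.py --output_dir=/tmp/rlmb_run \
#       --epoch=last --simulated_env=True --zoom=4 --fps=20
#
# With --dry_run, the script resets the environment five times, takes 50
# scripted steps per episode, and returns without starting the interactive
# player.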
def make_real_env():
  env = player_utils.setup_and_load_epoch(
      hparams, data_dir=directories["data"],
      which_epoch_data=None)
  env = FlatBatchEnv(env)  # pylint: disable=redefined-variable-type
  return env
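# Usage sketch (assumption: `make_real_env` is nested in a scope that
# provides `hparams` and `directories`, as in `main` above):
#
#   env = make_real_env()
#   observation = env.reset()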