def main(_):
    """Build a playable environment from command-line flags and run it.

    Loads the hparams set named by ``FLAGS.loop_hparams_set``, parses
    ``FLAGS.loop_hparams`` into it, resolves the working directories, and
    constructs one of three players:

    * ``FLAGS.sim_and_real``: a ``SimAndRealEnvPlayer`` over a real and a
      simulated environment;
    * ``FLAGS.simulated_env``: a ``SingleEnvPlayer`` over a simulated
      environment loaded for the requested epoch;
    * otherwise: a ``SingleEnvPlayer`` over the real environment.

    The player is wrapped with a monitor writing to ``FLAGS.video_dir``.
    With ``FLAGS.dry_run`` set, the environment is exercised for 5 rounds of
    50 steps each (actions cycling through 0, 1, 2) instead of starting
    interactive play.

    Args:
      _: Unused positional argument (presumably the argv list handed over by
        the app runner -- confirm against the caller).
    """
    # gym.logger.set_level(gym.logger.DEBUG)
    hparams = registry.hparams(FLAGS.loop_hparams_set)
    hparams.parse(FLAGS.loop_hparams)

    # Not important for experiments past 2018
    known_hparams = hparams.values().keys()
    if "wm_policy_param_sharing" not in known_hparams:
        hparams.add_hparam("wm_policy_param_sharing", False)

    directories = player_utils.infer_paths(
        output_dir=FLAGS.output_dir,
        world_model=FLAGS.wm_dir,
        policy=FLAGS.policy_dir,
        data=FLAGS.episodes_data_dir,
    )

    if FLAGS.game_from_filenames:
        inferred_game = player_utils.infer_game_name_from_filenames(
            directories["data"])
        hparams.set_hparam("game", inferred_game)

    # A separate gym environment is created only to read the action meanings.
    meanings_env = gym.make(full_game_name(hparams.game))
    action_meanings = meanings_env.unwrapped.get_action_meanings()

    # "last" is passed through untouched; anything else must be an integer.
    if FLAGS.epoch == "last":
        epoch = FLAGS.epoch
    else:
        epoch = int(FLAGS.epoch)

    def _real_env():
        # The real environment is always loaded with which_epoch_data=None.
        batch_env = player_utils.setup_and_load_epoch(
            hparams,
            data_dir=directories["data"],
            which_epoch_data=None)
        return FlatBatchEnv(batch_env)

    def _simulated_env(setable_initial_frames, which_epoch_data):
        return player_utils.load_data_and_make_simulated_env(
            directories["data"],
            directories["world_model"],
            hparams,
            which_epoch_data=which_epoch_data,
            setable_initial_frames=setable_initial_frames)

    if FLAGS.sim_and_real:
        # Simulated environment is built first, then the real one.
        simulated = _simulated_env(
            setable_initial_frames=True, which_epoch_data=None)
        real = _real_env()
        player = SimAndRealEnvPlayer(real, simulated, action_meanings)
    elif FLAGS.simulated_env:
        simulated = _simulated_env(
            setable_initial_frames=False, which_epoch_data=epoch)
        player = SingleEnvPlayer(simulated, action_meanings)
    else:
        player = SingleEnvPlayer(_real_env(), action_meanings)

    env = player_utils.wrap_with_monitor(player, FLAGS.video_dir)

    if not FLAGS.dry_run:
        play.play(env, zoom=FLAGS.zoom, fps=FLAGS.fps)
        return

    # Dry run: drive the environment programmatically, no interactive window.
    env.unwrapped.get_keys_to_action()
    for _round in range(5):
        env.reset()
        for step_index in range(50):
            env.step(step_index % 3)
        env.step(PlayerEnv.RETURN_DONE_ACTION)  # reset
def main(_):
    """Build a playable environment from command-line flags and run it.

    Loads the hparams set named by ``FLAGS.loop_hparams_set``, parses
    ``FLAGS.loop_hparams`` into it, resolves the working directories, and
    creates either a simulated environment (``FLAGS.simulated_env``) or a
    real one wrapped in ``FlatBatchEnv``, both loaded for the requested
    epoch. The result is wrapped in ``PlayerEnvWrapper`` and then with a
    monitor writing to ``FLAGS.video_dir``.

    With ``FLAGS.dry_run`` set, the environment is exercised for 5 rounds of
    50 steps each (actions cycling through 0, 1, 2) instead of starting
    interactive play.

    Args:
      _: Unused positional argument (presumably the argv list handed over by
        the app runner -- confirm against the caller).
    """
    # gym.logger.set_level(gym.logger.DEBUG)
    hparams = registry.hparams(FLAGS.loop_hparams_set)
    hparams.parse(FLAGS.loop_hparams)

    # Not important for experiments past 2018
    known_hparams = hparams.values().keys()
    if "wm_policy_param_sharing" not in known_hparams:
        hparams.add_hparam("wm_policy_param_sharing", False)

    directories = player_utils.infer_paths(
        output_dir=FLAGS.output_dir,
        world_model=FLAGS.wm_dir,
        policy=FLAGS.policy_dir,
        data=FLAGS.episodes_data_dir,
    )

    # "last" is passed through untouched; anything else must be an integer.
    if FLAGS.epoch == "last":
        epoch = FLAGS.epoch
    else:
        epoch = int(FLAGS.epoch)

    if FLAGS.simulated_env:
        base_env = player_utils.load_data_and_make_simulated_env(
            directories["data"],
            directories["world_model"],
            hparams,
            which_epoch_data=epoch)
    else:
        # Only the real environment gets the FlatBatchEnv wrapper.
        batch_env = player_utils.setup_and_load_epoch(
            hparams,
            data_dir=directories["data"],
            which_epoch_data=epoch)
        base_env = FlatBatchEnv(batch_env)

    player = PlayerEnvWrapper(base_env)
    env = player_utils.wrap_with_monitor(player, FLAGS.video_dir)

    if not FLAGS.dry_run:
        play.play(env, zoom=FLAGS.zoom, fps=FLAGS.fps)
        return

    # Dry run: drive the environment programmatically, no interactive window.
    for _round in range(5):
        env.reset()
        for step_index in range(50):
            env.step(step_index % 3)
        env.step(PlayerEnvWrapper.RESET_ACTION)  # reset