コード例 #1
0
ファイル: player_utils.py プロジェクト: zjms/tensor2tensor
def load_data_and_make_simulated_env(data_dir,
                                     wm_dir,
                                     hparams,
                                     which_epoch_data="last",
                                     random_starts=True):
    """Load recorded epoch data and build a simulated gym env on top of it.

    Args:
        data_dir: Directory passed to `T2TGymEnv.setup_and_load_epoch` as
            `data_dir`.
        wm_dir: Directory passed to `make_simulated_gym_env` as
            `world_model_dir`.
        hparams: Hyperparameters object. It is deep-copied first, and only
            the copy is handed to the helpers below.
        which_epoch_data: Epoch selector forwarded to
            `T2TGymEnv.setup_and_load_epoch` (defaults to "last").
        random_starts: Forwarded unchanged to `make_simulated_gym_env`.

    Returns:
        Whatever `make_simulated_gym_env` returns for the loaded env.
    """
    # Work on a private copy; presumably this keeps downstream setup from
    # mutating the caller's hparams -- confirm against the helpers.
    hparams_copy = copy.deepcopy(hparams)

    loaded_env = T2TGymEnv.setup_and_load_epoch(
        hparams_copy,
        data_dir=data_dir,
        which_epoch_data=which_epoch_data,
    )
    simulated_env = make_simulated_gym_env(
        loaded_env,
        world_model_dir=wm_dir,
        hparams=hparams_copy,
        random_starts=random_starts,
    )
    return simulated_env
コード例 #2
0
def main(_):
    """Build a player environment from command-line flags and run it.

    Reads hyperparameters and directory locations from `FLAGS`, constructs
    either a simulated environment or one loaded from recorded data, wraps
    it with `PlayerEnvWrapper` and a monitor, and then either performs a
    short scripted dry run or hands the env to `play.play`.

    Args:
        _: Unused positional argument.
    """
    # For verbose gym output while debugging, call
    # gym.logger.set_level(gym.logger.DEBUG) here.
    hparams = registry.hparams(FLAGS.loop_hparams_set)
    hparams.parse(FLAGS.loop_hparams)

    # Not important for experiments past 2018: make sure the hparam exists
    # so later code can read it.
    known_hparams = hparams.values().keys()
    if "wm_policy_param_sharing" not in known_hparams:
        hparams.add_hparam("wm_policy_param_sharing", False)

    paths = player_utils.infer_paths(
        output_dir=FLAGS.output_dir,
        world_model=FLAGS.wm_dir,
        policy=FLAGS.policy_dir,
        data=FLAGS.episodes_data_dir,
    )

    # The epoch flag is either the literal "last" or a number given as text.
    if FLAGS.epoch == "last":
        epoch = FLAGS.epoch
    else:
        epoch = int(FLAGS.epoch)

    if FLAGS.simulated_env:
        env = player_utils.load_data_and_make_simulated_env(
            paths["data"],
            paths["world_model"],
            hparams,
            which_epoch_data=epoch,
        )
    else:
        loaded_env = T2TGymEnv.setup_and_load_epoch(
            hparams,
            data_dir=paths["data"],
            which_epoch_data=epoch,
        )
        env = FlatBatchEnv(loaded_env)

    # Both branches get the same wrappers: player controls, then monitoring.
    env = PlayerEnvWrapper(env)  # pylint: disable=redefined-variable-type
    env = player_utils.wrap_with_monitor(env, FLAGS.video_dir)

    if not FLAGS.dry_run:
        play.play(env, zoom=FLAGS.zoom, fps=FLAGS.fps)
        return

    # Dry run: 5 rounds of 50 steps, cycling through actions 0, 1, 2, each
    # round ending with the wrapper's reset action.
    for _round in range(5):
        env.reset()
        for step_index in range(50):
            env.step(step_index % 3)
        env.step(PlayerEnvWrapper.RESET_ACTION)