Exemplo n.º 1
0
    def test_ttyrec_every(self):
        """Check which files appear in the save dir when save_ttyrec_every=2.

        Runs ten episodes, each driven by four fixed key presses after which
        the episode is expected to be done. After every even-numbered episode
        the current directory listing is checked for the expected ttyrec and
        xlogfile names; at the end the xlogfile must contain ten lines.
        """
        savedir = pathlib.Path(".")
        env = gym.make(
            "NetHackChallenge-v0", save_ttyrec_every=2, savedir=str(savedir)
        )
        pid = os.getpid()
        xlogfile = "nle.%i.xlogfile" % pid
        keys = (" ", " ", "<", "y")

        for episode in range(10):
            env.reset()
            for key in keys:
                _, _, done, *_ = env.step(env.actions.index(ord(key)))
            assert done

            if episode % 2 == 0:
                files = {str(entry) for entry in savedir.iterdir()}
                # `files` includes xlogfile and ttyrecs.
                ttyrec_count = len(files) - 1
                assert ttyrec_count == episode // 2 + 1
                ttyrec = "nle.%i.%i.ttyrec%i.bz2" % (
                    pid,
                    episode,
                    nethack.TTYREC_VERSION,
                )
                assert ttyrec in files
                assert xlogfile in files

        with open(xlogfile, "r") as f:
            lines = f.readlines()

        assert len(lines) == 10
Exemplo n.º 2
0
    def test_wizkit_file(self):
        """Reset a wizard-mode env twice, passing the same wizkit items."""
        wizkit = ["meatball", "apple"]
        env = gym.make("NetHack-v0", wizard=True)

        env.reset(wizkit_items=wizkit)
        # TODO: Test inventory here.

        # A second reset with the same items must also go through.
        env.reset(wizkit_items=wizkit)
        del env
Exemplo n.º 3
0
 def test_ttyrec_every(self):
     """Check the save dir contents after every even-numbered reset.

     With ``save_ttyrec_every=2`` the current directory is expected to hold
     ``episode // 2 + 1`` entries after resetting for an even ``episode``,
     one of them being the ttyrec named after this process and episode.
     """
     savedir = pathlib.Path(".")
     env = gym.make("NetHackScore-v0", save_ttyrec_every=2, savedir=str(savedir))
     for episode in range(10):
         env.reset()
         if episode % 2 == 0:
             files = {str(entry) for entry in savedir.iterdir()}
             assert len(files) == episode // 2 + 1
             expected = "nle.%i.%i.ttyrec.bz2" % (os.getpid(), episode)
             assert expected in files
Exemplo n.º 4
0
 def test_render_ansi(self, env_name, rollout_len):
     """Check that ansi rendering yields a string of the dungeon's size.

     Takes ``rollout_len`` random actions (resetting whenever an episode
     ends) and, after each one, verifies that ``render(mode="ansi")``
     returns a ``str`` whose length without newlines equals the product of
     ``nle.env.DUNGEON_SHAPE``.
     """
     env = gym.make(env_name)
     env.reset()
     for _ in range(rollout_len):
         _obs, _reward, done, _info = env.step(env.action_space.sample())
         if done:
             env.reset()
         text = env.render(mode="ansi")
         assert isinstance(text, str)
         visible_chars = len(text) - text.count("\n")
         assert visible_chars == np.prod(nle.env.DUNGEON_SHAPE)
Exemplo n.º 5
0
    def test_kick_and_quit(self, env):
        """Kick, check the direction prompt, then quit and expect zero reward."""
        env.reset()

        kick_action = env.actions.index(nethack.Command.KICK)
        obs, _, _, _ = env.step(kick_action)
        prompt = bytes(obs["message"])
        assert b"In what direction? " in prompt
        env.step(nethack.MiscAction.MORE)

        # Hack to quit: send M-q straight to the underlying nethack object,
        # then answer the follow-up with "y" through the env.
        env.nethack.step(nethack.M("q"))
        confirm = env.actions.index(ord("y"))
        _, reward, done, _ = env.step(confirm)

        assert done
        assert reward == 0.0
Exemplo n.º 6
0
    def test_inventory(self, env_name):
        """Check the inventory observations returned by the first reset.

        Verifies that at least one inventory string mentions "spellbook" and
        at least one mentions "apple", and that the first non-gold inventory
        slot has letter "a" and object class ``nethack.ARMOR_CLASS``.
        """
        env = gym.make(
            env_name,
            observation_keys=(
                "chars",
                "inv_glyphs",
                "inv_strs",
                "inv_letters",
                "inv_oclasses",
            ),
        )
        obs = env.reset()

        # Check the key is present before indexing with it, so a missing key
        # fails as an assertion rather than a KeyError.
        assert "inv_strs" in obs

        found = dict(spellbook=0, apple=0)
        for line in obs["inv_strs"]:
            # Stop at the first all-zero row (presumably the end of the
            # populated inventory entries).
            if np.all(line == 0):
                break
            text = line.tobytes().decode("utf-8")
            for key in found:
                if key in text:
                    found[key] += 1

        for key, count in found.items():
            # The message names the missing item when the assertion fails.
            assert count > 0, "no %r found in inventory" % key

        index = 0
        if obs["inv_letters"][index] != ord("a"):
            # We autopickedup some gold.
            assert obs["inv_letters"][index] == ord("$")
            assert obs["inv_oclasses"][index] == nethack.COIN_CLASS
            index = 1

        assert obs["inv_letters"][index] == ord("a")
        assert obs["inv_oclasses"][index] == nethack.ARMOR_CLASS
Exemplo n.º 7
0
    def test_wizkit_file(self):
        """Check that resetting with wizkit items writes the wizkit file.

        After the first reset the file must exist and its leading lines must
        match the requested items in order. After a second reset with the
        same items the file must contain exactly as many lines as items.
        """
        env = gym.make("NetHack-v0", wizard=True)
        req_items = ["meatball", "apple"]
        env.reset(wizkit_items=req_items)
        path_to_wizkit = os.path.join(env.env._vardir, nethack.nethack.WIZKIT_FNAME)

        # test file exists (the result must be asserted, not just computed)
        assert os.path.exists(path_to_wizkit)

        # test that file content corresponds to what you requested
        with open(path_to_wizkit, "r") as f:
            for item, line in zip(req_items, f):
                assert item == line.strip()

        # A second reset must leave exactly one line per requested item.
        env.reset(wizkit_items=req_items)
        with open(path_to_wizkit, "r") as f:
            lines = f.readlines()
        assert len(lines) == len(req_items)
        del env
Exemplo n.º 8
0
def main(args):
    """Time random-action stepping of the env built by ``make_venv`` and print FPS.

    Reads ``args.num_steps``, ``args.num_env`` and ``args.subproc``. On each
    step one random integer in ``[0, 8)`` is drawn and the same value is sent
    to all ``args.num_env`` environments. The env is reset whenever the step
    index is congruent to 1 modulo 200.
    """
    env = make_venv(args)
    env.reset()

    began = time.time()
    for step in range(args.num_steps):
        actions = [np.random.randint(8)] * args.num_env
        env.step(actions)
        if step % 200 == 1:
            env.reset()
    elapsed = time.time() - began

    fps = args.num_steps / elapsed
    report = "Took {:.2f}s with subproc={} on {} steps on {} envs - {:.2f} FPS".format(
        elapsed,
        args.subproc,
        args.num_steps,
        args.num_env,
        fps,
    )
    print(report)
Exemplo n.º 9
0
    def test_chars_colors_specials(self, env_name):
        """Check the map cell addressed by the first two status entries.

        The first two values of ``obs["status"]`` are used as (x, y); the
        ``chars`` cell there must be "@" and its ``colors`` value must be 15.
        """
        keys = ("chars", "colors", "specials", "status")
        env = gym.make(env_name, observation_keys=keys)
        obs = env.reset()

        assert "specials" in obs
        col, row = obs["status"][:2]

        # That's where you're @.
        assert obs["chars"][row, col] == ord("@")

        # You're bright (4th bit, 8) white (7), too.
        bright_white = 8 ^ 7
        assert obs["colors"][row, col] == bright_white
Exemplo n.º 10
0
 def test_meatball_exists(self):
     """Test loading stuff via wizkit.

     Resets a wizard-mode env with "meatball" as the only wizkit item and
     checks that at least one inventory string mentions it.
     """
     env = gym.make("NetHack-v0", wizard=True)
     found = dict(meatball=0)
     obs = env.reset(wizkit_items=list(found))
     for line in obs["inv_strs"]:
         # Stop at the first all-zero row (presumably the end of the
         # populated inventory entries).
         if np.all(line == 0):
             break
         text = line.tobytes().decode("utf-8")
         for key in found:
             if key in text:
                 found[key] += 1
     for key, count in found.items():
         # The message names the missing item when the assertion fails.
         assert count > 0, "no %r found in inventory" % key
     del env
Exemplo n.º 11
0
def rollout_env(env, max_rollout_len):
    """Produces a rollout and asserts step outputs.

    Does *not* assume that the environment has already been reset.

    Args:
        env: object providing ``reset()``, ``step(action)``, ``close()``,
            ``observation_space.contains(obs)`` and ``action_space.sample()``.
        max_rollout_len: upper bound on the number of steps taken. At least
            one step is always performed, even if this is zero or negative.

    Raises:
        AssertionError: if an observation is not contained in the observation
            space, or a step returns a reward/done/info of the wrong type.
    """
    try:
        obs = env.reset()
        assert env.observation_space.contains(obs)

        step = 0
        while True:
            a = env.action_space.sample()
            obs, reward, done, info = env.step(a)
            # Count the step so that max_rollout_len actually bounds the loop.
            step += 1
            assert env.observation_space.contains(obs)
            assert isinstance(reward, float)
            assert isinstance(done, bool)
            assert isinstance(info, dict)
            if done or step >= max_rollout_len:
                break
    finally:
        # Close the env even when one of the assertions above fails.
        env.close()
Exemplo n.º 12
0
def rollout_env(env, max_rollout_len):
    """Produces a rollout and asserts step outputs.

    Does not assume that the environment has already been reset.

    Args:
        env: object providing ``reset()``, ``step(action)``, ``close()``,
            ``observation_space.contains(obs)`` and ``action_space.sample()``.
        max_rollout_len: maximum number of steps to take.

    Returns:
        The reward of the last step taken, or ``0.0`` if no step was taken
        because ``max_rollout_len`` is zero or negative.

    Raises:
        AssertionError: if an observation is not contained in the observation
            space, or a step returns a reward/done/info of the wrong type.
    """
    # Default so that a zero-length rollout returns a value instead of
    # hitting an unbound local at the `return` below.
    reward = 0.0
    try:
        obs = env.reset()
        assert env.observation_space.contains(obs)

        for _ in range(max_rollout_len):
            a = env.action_space.sample()
            obs, reward, done, info = env.step(a)
            assert env.observation_space.contains(obs)
            assert isinstance(reward, float)
            assert isinstance(done, bool)
            assert isinstance(info, dict)
            if done:
                break
    finally:
        # Close the env even when one of the assertions above fails.
        env.close()
    return reward
Exemplo n.º 13
0
    def test_final_reward(self, env):
        """Check that the step which ends an episode yields zero reward.

        Takes up to 100 random actions. If the episode ends on its own, the
        reward of that last step must be 0.0. Otherwise the game is quit by
        hand and the reward of the quitting step must be 0.0.
        """
        env.reset()

        for _ in range(100):
            _, reward, done, _ = env.step(env.action_space.sample())
            if done:
                assert reward == 0.0
                return

        # Hopefully, we got some positive reward by now.

        # Get out of any menu / yn_function.
        carriage_return = env.actions.index(ord("\r"))
        env.step(carriage_return)

        # Hack to quit: send M-q straight to the underlying nethack object,
        # then answer the follow-up with "y" through the env.
        env.nethack.step(nethack.M("q"))
        confirm = env.actions.index(ord("y"))
        _, reward, done, _ = env.step(confirm)

        assert done
        assert reward == 0.0
Exemplo n.º 14
0
 def test_reset(self, env_name):
     """Check that the observation from reset() lies in the observation space."""
     env = gym.make(env_name)
     first_obs = env.reset()
     space = env.observation_space
     assert space.contains(first_obs)
Exemplo n.º 15
0
 def test_wizkit_no_wizard_mode(self):
     """Passing wizkit items with wizard=False must raise ValueError."""
     expected = "Set wizard=True to use the wizkit option."
     env = gym.make("NetHack-v0", wizard=False)
     with pytest.raises(ValueError) as excinfo:
         env.reset(wizkit_items=["meatball"])
     message = excinfo.value.args[0]
     assert message == expected