def test_ttyrec_every(self):
    """Check which files appear in the save directory with save_ttyrec_every=2.

    Ten episodes are played; each one is ended by sending a fixed sequence
    of key presses. After every even-numbered episode the current directory
    must hold one ttyrec per even episode so far plus the xlogfile, and at
    the end the xlogfile must have one line per episode.
    """
    savedir = pathlib.Path(".")
    env = gym.make(
        "NetHackChallenge-v0", save_ttyrec_every=2, savedir=str(savedir)
    )
    pid = os.getpid()
    # Key presses that are expected to finish the episode.
    quit_keys = [ord(" "), ord(" "), ord("<"), ord("y")]

    for episode in range(10):
        env.reset()
        for key in quit_keys:
            _, _, done, *_ = env.step(env.actions.index(key))
        assert done

        if episode % 2 == 0:
            contents = {str(entry) for entry in savedir.iterdir()}
            ttyrec_name = "nle.%i.%i.ttyrec%i.bz2" % (
                pid,
                episode,
                nethack.TTYREC_VERSION,
            )
            # `contents` includes xlogfile and ttyrecs.
            assert len(contents) - 1 == episode // 2 + 1
            assert ttyrec_name in contents
            assert "nle.%i.xlogfile" % pid in contents

    with open("nle.%i.xlogfile" % pid, "r") as xlogfile:
        entries = xlogfile.readlines()
    assert len(entries) == 10
def test_wizkit_file(self):
    """Reset a wizard-mode env twice with the same wizkit items."""
    req_items = ["meatball", "apple"]
    env = gym.make("NetHack-v0", wizard=True)

    # TODO: Test inventory here (after the first reset).
    for _ in range(2):
        env.reset(wizkit_items=req_items)

    del env
def test_ttyrec_every(self):
    """Check that a ttyrec shows up for every second reset.

    With save_ttyrec_every=2, after resetting for an even-numbered episode
    the current directory must contain exactly one entry per even episode
    so far, including the ttyrec named after this process and episode.
    """
    savedir = pathlib.Path(".")
    env = gym.make("NetHackScore-v0", save_ttyrec_every=2, savedir=str(savedir))

    for episode in range(10):
        env.reset()
        if episode % 2 == 0:
            contents = {str(entry) for entry in savedir.iterdir()}
            expected_name = "nle.%i.%i.ttyrec.bz2" % (os.getpid(), episode)
            assert len(contents) == episode // 2 + 1
            assert expected_name in contents
def test_render_ansi(self, env_name, rollout_len):
    """Take random steps, then check the shape of the "ansi" rendering.

    The rendering must be a string whose length, newlines excluded, equals
    the number of cells in nle.env.DUNGEON_SHAPE.
    """
    env = gym.make(env_name)
    env.reset()

    for _ in range(rollout_len):
        _obs, _reward, done, _info = env.step(env.action_space.sample())
        if done:
            # Start a fresh episode so stepping can continue.
            env.reset()

    rendered = env.render(mode="ansi")
    assert isinstance(rendered, str)

    without_newlines = rendered.replace("\n", "")
    assert len(without_newlines) == np.prod(nle.env.DUNGEON_SHAPE)
def test_kick_and_quit(self, env):
    """Issue a kick, check the direction prompt, then quit the game.

    Quitting must end the episode with a reward of exactly 0.0.
    """
    env.reset()

    kick_action = env.actions.index(nethack.Command.KICK)
    obs, reward, done, _ = env.step(kick_action)
    message = bytes(obs["message"])
    assert b"In what direction? " in message

    env.step(nethack.MiscAction.MORE)

    # Hack to quit.
    env.nethack.step(nethack.M("q"))
    confirm_action = env.actions.index(ord("y"))
    obs, reward, done, _ = env.step(confirm_action)

    assert done
    assert reward == 0.0
def test_inventory(self, env_name):
    """Check the inventory observations returned by the first reset.

    The inventory strings must mention both a spellbook and an apple. The
    first inventory slot must be letter "a" with the armor object class,
    optionally preceded by a "$" coin slot.
    """
    env = gym.make(
        env_name,
        observation_keys=(
            "chars",
            "inv_glyphs",
            "inv_strs",
            "inv_letters",
            "inv_oclasses",
        ),
    )
    obs = env.reset()

    # Check for the key before indexing it, so a missing observation fails
    # as an assertion instead of a KeyError further down.
    assert "inv_strs" in obs

    found = dict(spellbook=0, apple=0)
    for line in obs["inv_strs"]:
        if np.all(line == 0):
            # An all-zero row is treated as the end of the inventory.
            break
        text = line.tobytes().decode("utf-8")
        for key in found:
            if key in text:
                found[key] += 1

    for key, count in found.items():
        # Name the missing item in the failure message instead of relying
        # on a tautological `key == key` comparison to surface it.
        assert count > 0, "expected %r in the starting inventory" % key

    index = 0
    if obs["inv_letters"][index] != ord("a"):
        # We autopickedup some gold.
        assert obs["inv_letters"][index] == ord("$")
        assert obs["inv_oclasses"][index] == nethack.COIN_CLASS
        index = 1

    assert obs["inv_letters"][index] == ord("a")
    assert obs["inv_oclasses"][index] == nethack.ARMOR_CLASS
def test_wizkit_file(self):
    """Check the wizkit file written when resetting with wizkit_items.

    After the first reset the file must exist in the env's vardir and its
    lines must match the requested items. After a second reset with the
    same items the file must still have one line per requested item.
    """
    env = gym.make("NetHack-v0", wizard=True)
    req_items = ["meatball", "apple"]
    env.reset(wizkit_items=req_items)

    path_to_wizkit = os.path.join(env.env._vardir, nethack.nethack.WIZKIT_FNAME)

    # The existence check has to be asserted; a bare call to
    # os.path.exists() discards the result and can never fail the test.
    assert os.path.exists(path_to_wizkit)

    # Test that the file content corresponds to what was requested.
    with open(path_to_wizkit, "r") as f:
        for item, line in zip(req_items, f):
            assert item == line.strip()

    env.reset(wizkit_items=req_items)

    # The line count must equal the number of requested items after the
    # second reset too.
    with open(path_to_wizkit, "r") as f:
        lines = f.readlines()
    assert len(lines) == len(req_items)

    del env
def main(args):
    """Step a vectorised env with random actions and print the throughput.

    Reads num_steps, num_env and subproc from `args`. At each step one
    random integer below 8 is drawn and sent to every env. The envs are
    reset whenever the step index is one past a multiple of 200.
    """
    env = make_venv(args)
    env.reset()

    began = time.time()
    for step_index in range(args.num_steps):
        action = np.random.randint(8)
        env.step([action] * args.num_env)
        if step_index % 200 == 1:
            env.reset()
    elapsed = time.time() - began

    report = "Took {:.2f}s with subproc={} on {} steps on {} envs - {:.2f} FPS".format(
        elapsed,
        args.subproc,
        args.num_steps,
        args.num_env,
        args.num_steps / elapsed,
    )
    print(report)
def test_chars_colors_specials(self, env_name):
    """Check the player's cell in the chars and colors observations.

    The first two status entries are used as the (x, y) position; the cell
    there must hold "@" and have the colour value 8 ^ 7.
    """
    wanted_keys = ("chars", "colors", "specials", "status")
    env = gym.make(env_name, observation_keys=wanted_keys)
    obs = env.reset()
    assert "specials" in obs

    # That's where you're @.
    col, row = obs["status"][:2]
    assert obs["chars"][row, col] == ord("@")

    # You're bright (4th bit, 8) white (7), too.
    bright_white = 8 ^ 7
    assert obs["colors"][row, col] == bright_white
def test_meatball_exists(self):
    """Test loading stuff via wizkit.

    Resets a wizard-mode env with "meatball" as the only wizkit item and
    checks that the inventory strings of the first observation mention it.
    """
    env = gym.make("NetHack-v0", wizard=True)
    found = dict(meatball=0)
    obs = env.reset(wizkit_items=list(found.keys()))

    for line in obs["inv_strs"]:
        if np.all(line == 0):
            # An all-zero row is treated as the end of the inventory.
            break
        text = line.tobytes().decode("utf-8")
        for key in found:
            if key in text:
                found[key] += 1

    for key, count in found.items():
        # Name the missing item in the failure message instead of relying
        # on a tautological `key == key` comparison to surface it.
        assert count > 0, "expected %r in the inventory" % key

    del env
def rollout_env(env, max_rollout_len):
    """Produces a rollout and asserts step outputs.

    Does *not* assume that the environment has already been reset.

    Args:
        env: Environment exposing reset(), step(), close(), action_space
            and observation_space.
        max_rollout_len: Number of steps after which the rollout stops even
            if the episode has not finished. At least one step is always
            taken.
    """
    obs = env.reset()
    assert env.observation_space.contains(obs)

    step = 0
    while True:
        a = env.action_space.sample()
        obs, reward, done, info = env.step(a)
        assert env.observation_space.contains(obs)
        assert isinstance(reward, float)
        assert isinstance(done, bool)
        assert isinstance(info, dict)

        # Count the step just taken. Without this the length limit below
        # never triggers and a non-terminating episode loops forever.
        step += 1
        if done or step >= max_rollout_len:
            break

    env.close()
def rollout_env(env, max_rollout_len):
    """Run one rollout of random actions, validating every step's outputs.

    The environment is reset here, so callers need not reset it first. The
    rollout stops when the episode is done or after max_rollout_len steps,
    whichever comes first; the environment is then closed.

    Returns:
        The reward of the last step taken.
    """
    initial_obs = env.reset()
    assert env.observation_space.contains(initial_obs)

    for _step in range(max_rollout_len):
        action = env.action_space.sample()
        observation, reward, done, info = env.step(action)

        assert env.observation_space.contains(observation)
        for value, expected_type in ((reward, float), (done, bool), (info, dict)):
            assert isinstance(value, expected_type)

        if done:
            break

    env.close()
    return reward
def test_final_reward(self, env):
    """Check that the step which ends an episode carries a reward of 0.0.

    Up to 100 random actions are taken. If the episode ends on its own, the
    final reward is checked right there; otherwise the game is quit by hand
    and the reward of the quitting step is checked instead.
    """
    env.reset()

    for _ in range(100):
        _, reward, done, _ = env.step(env.action_space.sample())
        if done:
            assert reward == 0.0
            return

    # Hopefully, we got some positive reward by now.

    # Get out of any menu / yn_function.
    enter_action = env.actions.index(ord("\r"))
    env.step(enter_action)

    # Hack to quit.
    env.nethack.step(nethack.M("q"))
    confirm_action = env.actions.index(ord("y"))
    _, reward, done, _ = env.step(confirm_action)

    assert done
    assert reward == 0.0
def test_reset(self, env_name):
    """Check that the first observation lies in the env's observation space.

    The env is created from its spec with default arguments.
    """
    env = gym.make(env_name)
    first_obs = env.reset()
    space = env.observation_space
    assert space.contains(first_obs)
def test_wizkit_no_wizard_mode(self):
    """Check that wizkit items are rejected when wizard mode is off.

    Resetting with wizkit_items on an env made with wizard=False must raise
    a ValueError carrying the expected message.
    """
    expected_message = "Set wizard=True to use the wizkit option."
    env = gym.make("NetHack-v0", wizard=False)

    with pytest.raises(ValueError) as excinfo:
        env.reset(wizkit_items=["meatball"])

    raised = excinfo.value
    assert raised.args[0] == expected_message