class StateTracking(textworld.core.Wrapper):
    """
    Wrapper that enables state tracking for Inform7 games generated by TextWorld.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._gamefile = None
        self._game = None
        self._inform7 = None
        self._last_action = None
        self._previous_winning_policy = None
        self._current_winning_policy = None
        self._moves = None
        self._game_progression = None

    @property
    def tracking(self):
        """ Whether any of the requested infos needs the game's state to be tracked. """
        # NOTE(review): `infos.moves` is not part of this check, so when only
        # `moves` is requested, `state["moves"]` is not set by this wrapper
        # (reset/step return early) -- confirm this is intended.
        return (self.infos.intermediate_reward
                or self.infos.policy_commands
                or self.infos.admissible_commands
                or self.infos.facts
                or self.infos.last_action)

    def load(self, gamefile: str) -> None:
        """ Load a game in the wrapped environment and get its `Game` description.

        Arguments:
            gamefile: Path to the game file. If the wrapped environment does
                      not expose a `_game` attribute, the `Game` is loaded from
                      the JSON file sharing the same basename.

        Raises:
            MissingGameInfosError: if that JSON file is needed but missing.
        """
        self._wrapped_env.load(gamefile)
        self._gamefile = os.path.splitext(gamefile)[0] + ".json"
        try:
            self._game = self._wrapped_env._game
        except AttributeError:
            if not os.path.isfile(self._gamefile):
                raise MissingGameInfosError(self)

            self._game = Game.load(self._gamefile)

        self._game_progression = None
        self._inform7 = Inform7Game(self._game)

    def _compute_intermediate_reward(self) -> int:
        """ Return the intermediate reward for the last step.

        +1 if the game was won, -1 if it was lost, otherwise the sign of the
        decrease in the winning policy's length (shorter policy -> +1).
        """
        if self.state["won"]:
            return 1  # The last action led to winning the game.

        if self.state["lost"]:
            return -1  # The last action led to losing the game.

        if self._previous_winning_policy is None or self._current_winning_policy is None:
            # Nothing to compare against (e.g. right after reset, or no winning
            # policy is available anymore): stay neutral instead of failing on len(None).
            return 0

        diff = len(self._previous_winning_policy) - len(self._current_winning_policy)
        return int(diff > 0) - int(diff < 0)  # Sign function.

    def _gather_infos(self):
        """ Fill `self.state` with the tracked information requested in `self.infos`. """
        self.state["_game_progression"] = self._game_progression
        self.state["_facts"] = list(self._game_progression.state.facts)

        self.state["won"] = '*** The End ***' in self.state["feedback"]
        self.state["lost"] = '*** You lost! ***' in self.state["feedback"]

        self.state["_winning_policy"] = self._current_winning_policy
        if self.infos.policy_commands:
            self.state["policy_commands"] = []
            if self._current_winning_policy is not None:
                self.state["policy_commands"] = self._inform7.gen_commands_from_actions(self._current_winning_policy)

        if self.infos.intermediate_reward:
            self.state["intermediate_reward"] = self._compute_intermediate_reward()

        if self.infos.facts:
            self.state["facts"] = list(map(self._inform7.get_human_readable_fact, self.state["_facts"]))

        self.state["_last_action"] = self._last_action
        if self.infos.last_action and self._last_action is not None:
            self.state["last_action"] = self._inform7.get_human_readable_action(self._last_action)

        self.state["_valid_actions"] = self._game_progression.valid_actions
        if self.infos.admissible_commands:
            all_valid_commands = self._inform7.gen_commands_from_actions(self._game_progression.valid_actions)
            # To guarantee the order from one execution to another, we sort the commands.
            # Remove any potential duplicate commands (they would lead to the same result anyway).
            self.state["admissible_commands"] = sorted(set(all_valid_commands))

        if self.infos.moves:
            self.state["moves"] = self._moves

    def _send(self, command: str) -> str:
        """ Send a command to the game without affecting the Environment's state. """
        return self.unwrapped._send(command)

    def reset(self):
        """ Reset the wrapped environment and, if needed, restart state tracking. """
        self.state = self._wrapped_env.reset()
        if not self.tracking:
            return self.state  # State tracking not needed.

        self._send('tw-trace-actions')  # Turn on print for Inform7 action events.
        track_quests = (self.infos.intermediate_reward or self.infos.policy_commands)
        self._game_progression = GameProgression(self._game, track_quests=track_quests)
        self._last_action = None
        self._previous_winning_policy = None
        self._current_winning_policy = self._game_progression.winning_policy
        self._moves = 0

        self._gather_infos()
        return self.state

    def step(self, command: str):
        """ Send `command` to the wrapped environment and update the tracked state.

        Returns:
            The `(state, score, done)` triple. When tracking is active, `done`
            is recomputed as `won or lost`.
        """
        self.state, score, done = self._wrapped_env.step(command)
        if not self.tracking:
            return self.state, score, done  # State tracking not needed.

        # Detect what events just happened in the game.
        i7_events, self.state["feedback"] = _detect_i7_events_debug_tags(self.state["feedback"])
        if str2bool(os.environ.get("TEXTWORLD_DEBUG", False)):
            print("[DEBUG] Detected Inform7 events:\n{}\n".format(i7_events))

        self._previous_winning_policy = self._current_winning_policy
        for i7_event in i7_events:
            valid_actions = self._game_progression.valid_actions
            self._last_action = self._inform7.detect_action(i7_event, valid_actions)
            if self._last_action is not None:
                # An action that affects the state of the game.
                self._game_progression.update(self._last_action)
                self._current_winning_policy = self._game_progression.winning_policy
                self._moves += 1

        self._gather_infos()
        self.state["done"] = self.state["won"] or self.state["lost"]
        return self.state, score, self.state["done"]

    def copy(self) -> "StateTracking":
        """ Returns a copy this wrapper. """
        env = StateTracking()
        env._wrapped_env = self._wrapped_env.copy()
        env._gamefile = self._gamefile
        env._game = self._game  # Reference
        env._inform7 = self._inform7  # Reference
        env._last_action = self._last_action
        env._moves = self._moves
        if self._previous_winning_policy is not None:
            env._previous_winning_policy = list(self._previous_winning_policy)

        if self._current_winning_policy is not None:
            env._current_winning_policy = list(self._current_winning_policy)

        if self._game_progression is not None:
            env._game_progression = self._game_progression.copy()

        return env
class TextWorldEnv(textworld.Environment):
    """
    Environment for playing games by TextWorld.
    """

    def __init__(self, infos: Optional[EnvInfos] = None) -> None:
        """
        Arguments:
            infos: Information to be included in the game state. By
                   default, only the game's narrative is included.
        """
        super().__init__(infos)
        self._gamefile = None
        self._game = None
        self._inform7 = None
        self._last_action = None
        self._prev_state = None
        self._previous_winning_policy = None
        self._current_winning_policy = None
        self._moves = None
        self._game_progression = None

    def load(self, path: str) -> None:
        """ Load the TextWorld game stored at `path`. """
        self._gamefile = path
        self._game = textworld.Game.load(self._gamefile)
        self._game_progression = None
        self._inform7 = Inform7Game(self._game)

    def _compute_intermediate_reward(self) -> int:
        """ Return the intermediate reward for the last step.

        +1 if the game was won, -1 if it was lost, otherwise the sign of the
        decrease in the winning policy's length (shorter policy -> +1).
        """
        if self.state["won"]:
            return 1  # The last action led to winning the game.

        if self.state["lost"]:
            return -1  # The last action led to losing the game.

        if self._previous_winning_policy is None or self._current_winning_policy is None:
            # Nothing to compare against (e.g. right after reset, or no winning
            # policy is available anymore): stay neutral instead of failing on len(None).
            return 0

        diff = len(self._previous_winning_policy) - len(self._current_winning_policy)
        return int(diff > 0) - int(diff < 0)  # Sign function.

    def _gather_infos(self):
        """ Fill `self.state` with the game information and the infos requested in `self.infos`. """
        self.state["game"] = self._game
        self.state["command_templates"] = self._game.command_templates
        self.state["verbs"] = self._game.verbs
        self.state["entities"] = self._game.entity_names
        self.state["objective"] = self._game.objective
        self.state["max_score"] = self._game.max_score

        for k, v in self._game.metadata.items():
            self.state["extra.{}".format(k)] = v

        self.state["_game_progression"] = self._game_progression
        self.state["_facts"] = list(self._game_progression.state.facts)

        self.state["won"] = self._game_progression.completed
        self.state["lost"] = self._game_progression.failed

        self.state["_winning_policy"] = self._current_winning_policy
        if self.infos.policy_commands:
            self.state["policy_commands"] = []
            # Test the same policy we convert (avoids recomputing `winning_policy`).
            if self._current_winning_policy is not None:
                self.state["policy_commands"] = self._inform7.gen_commands_from_actions(self._current_winning_policy)

        if self.infos.intermediate_reward:
            self.state["intermediate_reward"] = self._compute_intermediate_reward()

        if self.infos.facts:
            self.state["facts"] = list(map(self._inform7.get_human_readable_fact, self.state["_facts"]))

        self.state["last_action"] = None
        self.state["_last_action"] = self._last_action
        if self.infos.last_action and self._last_action is not None:
            self.state["last_action"] = self._inform7.get_human_readable_action(self._last_action)

        # Compute the valid actions once so `_valid_actions[i]` is exactly the
        # action that produced `_valid_commands[i]` (`step` relies on this pairing).
        valid_actions = self._game_progression.valid_actions
        self.state["_valid_actions"] = valid_actions
        self.state["_valid_commands"] = self._inform7.gen_commands_from_actions(valid_actions)
        # To guarantee the order from one execution to another, we sort the commands.
        # Remove any potential duplicate commands (they would lead to the same result anyway).
        self.state["admissible_commands"] = sorted(set(self.state["_valid_commands"]))

        if self.infos.moves:
            self.state["moves"] = self._moves

    def reset(self):
        """ Start a new episode of the loaded game and return the initial state. """
        self._prev_state = None
        self.state = GameState()
        track_quests = (self.infos.intermediate_reward or self.infos.policy_commands)
        self._game_progression = GameProgression(self._game, track_quests=track_quests)
        self._last_action = None
        self._previous_winning_policy = None
        self._current_winning_policy = self._game_progression.winning_policy
        self._moves = 0

        self.state.raw = DEFAULT_OBSERVATION
        self.state.feedback = DEFAULT_OBSERVATION
        self._gather_infos()
        return self.state

    def step(self, command: str):
        """ Apply `command` to the game.

        A command that is not one of the previous state's valid commands
        leaves the game unchanged and sets the feedback to "Invalid command.".

        Returns:
            The `(state, score, done)` triple, where `score` is the game
            progression's score and `done` is `won or lost`.
        """
        command = command.strip()
        self._prev_state = self.state

        self.state = GameState()
        self.state.last_command = command
        self.state.raw = DEFAULT_OBSERVATION
        self.state.feedback = DEFAULT_OBSERVATION
        self._previous_winning_policy = self._current_winning_policy

        self._last_action = None
        try:
            # Find the action corresponding to the command.
            idx = self._prev_state["_valid_commands"].index(command)
        except ValueError:
            # Unknown command: we assume nothing happened in the game.
            self.state.feedback = "Invalid command."
        else:
            # `_valid_commands` was generated from `_valid_actions` in `_gather_infos`.
            self._last_action = self._prev_state["_valid_actions"][idx]
            # An action that affects the state of the game.
            self._game_progression.update(self._last_action)
            self._current_winning_policy = self._game_progression.winning_policy
            self._moves += 1

        self._gather_infos()
        self.state["score"] = self._game_progression.score
        self.state["done"] = self.state["won"] or self.state["lost"]
        return self.state, self.state["score"], self.state["done"]

    def copy(self) -> "TextWorldEnv":
        """ Return a copy of this environment.

        It is safe to call `step` and `reset` on the copied environment.

        .. warning:: The `Game` and `Inform7Game` private objects are *soft* copies.
        """
        env = TextWorldEnv()

        # Copy core Environment's attributes.
        env.state = self.state.copy()
        env.infos = self.infos.copy()

        env._gamefile = self._gamefile
        env._game = self._game  # Reference
        env._inform7 = self._inform7  # Reference

        env._prev_state = self._prev_state.copy() if self._prev_state is not None else None
        env._last_action = self._last_action
        env._moves = self._moves
        if self._previous_winning_policy is not None:
            env._previous_winning_policy = tuple(self._previous_winning_policy)

        if self._current_winning_policy is not None:
            env._current_winning_policy = tuple(self._current_winning_policy)

        if self._game_progression is not None:
            env._game_progression = self._game_progression.copy()

        return env