Example #1
0
 def _next_state_reward(self, state: State, action: Action) -> StateReward:
     """Return the successor state and reward for taking *action* in *state*."""
     x, y = state.value
     # Walls and the goal cell are absorbing: stay in place with zero reward.
     if state.value in self.walls or state.value == self.goal:
         return StateReward(State((x, y), state.is_terminal), 0.0)
     # Map action index -> target cell; any index other than 0/1/2 moves down,
     # exactly as the original trailing else-branch did.
     targets = {0: (x + 1, y), 1: (x, y + 1), 2: (x - 1, y)}
     destination = targets.get(action.value, (x, y - 1))
     to_pos, reward, is_end = self._transit((x, y), destination)
     return StateReward(State(to_pos, is_end), reward)
Example #2
0
    def step(self, policy: RLPolicy):
        """Advance the environment one step under *policy*.

        Samples an action from the policy, samples a successor state from
        this model's transition distribution, updates internal state, and
        returns the resulting Transition.
        """
        action_dist = policy(self.current_state)
        action = action_dist.sample()
        # Some policies return a one-element list from sample(); unwrap it.
        if isinstance(action, list):
            action = action[0]

        # Build parallel candidate/weight lists from the successor distribution
        # and draw one successor proportionally to its probability.
        next_dist = self(self.current_state, action)
        candidates = [StateReward(s, rp.reward) for s, rp in next_dist.items()]
        weights = [rp.prob for rp in next_dist.values()]
        chosen = random.choices(candidates, weights=weights)[0]

        previous = self.current_state
        stayed = chosen.state == previous
        if not stayed:
            self.current_state = chosen.state
        self._steps_taken += 1

        # Termination (horizon reached or terminal state) takes precedence
        # over a NOOP self-transition.
        if 0 < self._max_horizon <= self._steps_taken or self.current_state.is_terminal:
            status = Transition.Status.TERMINATED
        elif stayed:
            status = Transition.Status.NOOP
        else:
            status = Transition.Status.NORMAL

        return Transition(
            last_state=previous,
            action=action,
            action_prob=action_dist[action],
            state=self.current_state,
            reward=chosen.reward,
            status=status,
        )
Example #3
0
 def _next_state_reward(self, state: State, action: Action) -> StateReward:
     """Return the successor state and reward for taking *action* in *state*.

     Raises:
         TypeError: if the state's value is not a Tuple[int, int].
     """
     value = state.value
     # Validate with explicit raises instead of assert: assert statements are
     # silently stripped when Python runs with -O, which would remove this
     # input validation entirely.
     if not isinstance(value, tuple):
         raise TypeError(f"got type {type(value)} instead of tuple")
     (x, y) = value
     if not (isinstance(x, int) and isinstance(y, int)):
         raise TypeError("Gridworld expects states to be Tuple[int, int]")
     # Walls and the goal cell are absorbing: stay in place with zero reward.
     if state.value in self.walls or state.value == self.goal:
         return StateReward(State((x, y), state.is_terminal), 0.0)
     if action.value == 0:
         to_pos, reward, is_end = self._transit((x, y), (x + 1, y))
     elif action.value == 1:
         to_pos, reward, is_end = self._transit((x, y), (x, y + 1))
     elif action.value == 2:
         to_pos, reward, is_end = self._transit((x, y), (x - 1, y))
     else:
         # Any action index other than 0/1/2 moves down.
         to_pos, reward, is_end = self._transit((x, y), (x, y - 1))
     return StateReward(State(to_pos, is_end), reward)