예제 #1
0
    def __init__(self,
                 combo,
                 num_actions=3,
                 num_states=None,
                 reset_val=0.01,
                 gamma=0.99):
        '''
        Build a combination-lock chain MDP that starts in ChainState(1).

        Args:
            combo (list of int): The combination. combo[i] is the action number
                (1-based) that must be pressed while in state i + 1.
            num_actions (int) [optional]: Number of actions, named "1".."num_actions".
            num_states (int) [optional]: Number of states in the chain. Defaults
                to len(combo); if given explicitly it must equal len(combo).
            reset_val (float) [optional]: Accepted for signature compatibility;
                this constructor does not use it.
            gamma (float) [optional]: Discount factor forwarded to MDP.

        Raises:
            ValueError: If len(combo) != num_states, combo is empty, or combo
                contains a value outside [1, num_actions].
        '''
        num_states = len(combo) if num_states is None else num_states

        # Validate before touching any shared state, so that a rejected
        # configuration leaves the class-level ComboLockMDP.ACTIONS untouched.
        if len(combo) != num_states:
            raise ValueError("(simple_rl.ComboLockMDP Error): Combo length (" +
                             str(len(combo)) +
                             ") must be the same as num_states (" +
                             str(num_states) + ").")
        if len(combo) == 0:
            # max()/min() below would otherwise fail with an opaque message.
            raise ValueError(
                "(simple_rl.ComboLockMDP Error): Combo must be non-empty.")
        if max(combo) > num_actions:
            raise ValueError(
                "(simple_rl.ComboLockMDP Error): Combo (" + str(combo) +
                ") must only contain values less than or equal to @num_actions ("
                + str(num_actions) + ").")
        if min(combo) < 1:
            # Actions are numbered from 1, so a value below 1 could never be
            # matched and the lock would be unsolvable.
            raise ValueError(
                "(simple_rl.ComboLockMDP Error): Combo (" + str(combo) +
                ") must only contain values greater than or equal to 1.")

        # NOTE: this is a class attribute, so the most recently constructed
        # instance determines ComboLockMDP.ACTIONS for every instance.
        ComboLockMDP.ACTIONS = [str(i) for i in range(1, num_actions + 1)]
        self.num_states = num_states
        self.num_actions = num_actions
        self.combo = combo

        MDP.__init__(self,
                     ComboLockMDP.ACTIONS,
                     self._transition_func,
                     self._reward_func,
                     init_state=ChainState(1),
                     gamma=gamma)
예제 #2
0
 def __init__(self, gamma, kappa=0.001):
     '''
     Build the fixed four-state chain MDP, starting in ChainState(1).

     Args:
         gamma (float): Discount factor forwarded to MDP.
         kappa (float) [optional]: Stored on the instance as self.kappa.
     '''
     start_state = ChainState(1)
     MDP.__init__(self,
                  BadChainMDP.ACTIONS,
                  self._transition_func,
                  self._reward_func,
                  init_state=start_state,
                  gamma=gamma)
     # The chain length is fixed rather than configurable.
     self.num_states, self.kappa = 4, kappa
예제 #3
0
    def _transition_func(self, state, action):
        '''
        Move along the chain for a "right" or "left" action.

        States are numbered 1..self.num_states. "right" steps to the next
        state; stepping into state self.num_states yields a new state flagged
        as terminal. "left" steps back unless already in state 1. A terminal
        state, or any other action, leaves the state unchanged.

        Args:
            state (State)
            action (str)

        Returns
            (State)
        '''
        if state.is_terminal():
            # Nothing moves once the episode has ended.
            return state

        if action == "right":
            successor = state.num + 1
            if successor == self.num_states:
                # Stepping onto the last state ends the episode.
                final = ChainState(self.num_states)
                final.set_terminal(True)
                return final
            if state.num < self.num_states - 1:
                return ChainState(successor)
            return state

        if action == "left" and state.num > 1:
            return ChainState(state.num - 1)

        # Unrecognised action, or "left" while in state 1: stay put.
        return state
예제 #4
0
 def __init__(self, num_states=5, reset_val=0.01, gamma=0.99):
     '''
     Build a forward/reset chain MDP that starts in ChainState(1).

     Args:
         num_states (int) [optional]: Number of states in the chain.
         reset_val (float) [optional]: Stored on the instance as self.reset_val.
         gamma (float) [optional]: Discount factor forwarded to MDP.
     '''
     start_state = ChainState(1)
     MDP.__init__(self,
                  ChainMDP.ACTIONS,
                  self._transition_func,
                  self._reward_func,
                  init_state=start_state,
                  gamma=gamma)
     self.num_states, self.reset_val = num_states, reset_val
예제 #5
0
    def _transition_func(self, state, action):
        '''
        Press one digit of the combination lock.

        Args:
            state (State): Current chain state; state.num is its 1-based index.
            action (str): The pressed digit as a string ("1", "2", ...).

        Returns
            (State): If the pressed digit equals self.combo[state.num - 1],
            the next state along the chain, or the same state when already
            at the last one. Otherwise ChainState(1): the lock resets.
        '''
        pressed = int(action)
        required = self.combo[state.num - 1]
        if pressed == required:
            # Correct digit: advance, unless the chain has no further state.
            return state + 1 if state < self.num_states else state
        # Wrong digit: start the combination over from the first state.
        return ChainState(1)
    def _transition_func(self, state, action):
        '''
        Transition function for the forward/reset chain.

        Args:
            state (State): Current chain state.
            action (str): "forward" or "reset".

        Returns
            (State): For "forward", the next state along the chain, or the same
            state once state is no longer below self.num_states. For "reset",
            ChainState(1).

        Raises:
            ValueError: If action is neither "forward" nor "reset".
        '''
        if action == "forward":
            # Advance one link; at the end of the chain, stay where we are.
            if state < self.num_states:
                return state + 1
            return state
        if action == "reset":
            return ChainState(1)
        # str() so that a non-string action still produces the intended
        # ValueError rather than a TypeError from string concatenation.
        raise ValueError("(simple_rl Error): Unrecognized action! (" + str(action) + ")")