Example #1
  def end_episode(self, reward, observation):
    # Decay the exploration temperature geometrically and rebuild the
    # exploration lambdas so they capture the updated temperature.
    self._exploration_temperature *= self._base_exploration_temperature
    self._exploration_functions = {
        'epsilon_greedy':
            lambda observation: agent_utils.epsilon_greedy_exploration(  # pylint: disable=g-long-lambda
                self._enumerate_state_action_indices(observation),
                self._q_value_table, self._exploration_temperature),
        'min_count':
            lambda observation: agent_utils.min_count_exploration(  # pylint: disable=g-long-lambda
                self._enumerate_state_action_indices(observation),
                self._state_action_counts),
    }
    self._previous_state_action_index = None
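
end_episode multiplies the running exploration temperature by the base temperature, so the temperature decays geometrically across episodes. A minimal standalone sketch of that decay; the variable names are illustrative and 0.99 is just the constructor default from Example #2, not a RecSim internal:

base_temperature = 0.99
temperature = base_temperature
for episode in range(3):
    temperature *= base_temperature
    print('after episode %d: temperature = %.4f' % (episode, temperature))
# after episode 0: temperature = 0.9801
# after episode 1: temperature = 0.9703
# after episode 2: temperature = 0.9606
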
Example #2
    def __init__(self,
                 observation_space,
                 action_space,
                 eval_mode=False,
                 ignore_response=True,
                 discretization_bounds=(0.0, 10.0),
                 number_bins=100,
                 exploration_policy='epsilon_greedy',
                 exploration_temperature=0.99,
                 learning_rate=0.1,
                 gamma=0.99,
                 **kwargs):
        """TabularQAgent init.

    Args:
      observation_space: a gym.spaces object specifying the format of
        observations.
      action_space: a gym.spaces object that specifies the format of actions.
      eval_mode: Boolean indicating whether the agent is in training or eval
        mode.
      ignore_response: Boolean indicating whether the agent should ignore the
        response part of the observation.
      discretization_bounds: pair of real numbers indicating the min and max
        value for continuous attributes discretization. Values below the min
        will all be grouped in the first bin, while values above the max will
        all be grouped in the last bin. See the documentation of numpy.digitize
        for further details.
      number_bins: positive integer number of bins used to discretize continuous
        attributes.
      exploration_policy: either one of ['epsilon_greedy', 'min_count'] or a
        custom function. TODO(mmladenov): formalize requirements of this
          function.
      exploration_temperature: a real number passed as parameter to the
        exploration policy.
      learning_rate: a real number between 0 and 1 indicating how much to update
        Q-values, i.e. Q_t+1(s,a) = (1 - learning_rate) * Q_t(s, a)
                                     + learning_rate * (R(s,a) + ...).
      gamma: real value between 0 and 1 indicating the discount factor of the
        MDP.
      **kwargs: additional arguments like eval_mode.
    """
        self._kwargs = kwargs
        super(TabularQAgent, self).__init__(action_space)
        # hard params
        self._gamma = gamma
        self._eval_mode = eval_mode
        self._previous_slate = None
        self._learning_rate = learning_rate
        # storage
        self._q_value_table = {}
        self._state_action_counts = {}
        self._previous_state_action_index = None
        # discretization and spaces
        self._discretization_bins = np.linspace(discretization_bounds[0],
                                                discretization_bounds[1],
                                                num=number_bins)
        single_doc_space = list(
            observation_space.spaces['doc'].spaces.values())[0]
        slate_tuple = tuple([single_doc_space] * self._slate_size)
        action_space = spaces.Tuple(slate_tuple)
        self._ignore_response = ignore_response
        state_action_space = {
            'user': observation_space.spaces['user'],
            'action': action_space
        }
        if not self._ignore_response:
            state_action_space['response'] = observation_space.spaces[
                'response']
        self._state_action_space = spaces.Dict(state_action_space)
        self._observation_featurizer = agent_utils.GymSpaceWalker(
            self._state_action_space, self._discretize_gym_leaf)
        # exploration
        self._exploration_policy = exploration_policy
        self._exploration_temperature = exploration_temperature
        self._base_exploration_temperature = self._exploration_temperature
        self._exploration_functions = {
            'epsilon_greedy':
                lambda observation: agent_utils.epsilon_greedy_exploration(  # pylint: disable=g-long-lambda
                    self._enumerate_state_action_indices(observation),
                    self._q_value_table, self._exploration_temperature),
            'min_count':
                lambda observation: agent_utils.min_count_exploration(  # pylint: disable=g-long-lambda
                    self._enumerate_state_action_indices(observation),
                    self._state_action_counts),
        }
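
The constructor builds the discretization bins with np.linspace (used together with numpy.digitize, per the docstring) and documents the learning_rate update rule, which is the standard tabular Q-learning step the docstring abbreviates with "...". A minimal sketch of both pieces, assuming a dict-backed Q-table keyed by (state, action) tuples; q_update and the example keys are hypothetical and are not RecSim's internal representation:

import numpy as np

bins = np.linspace(0.0, 10.0, num=100)   # same call as self._discretization_bins
bucket = int(np.digitize(3.7, bins))     # continuous attribute value -> bin index

q_table = {}
learning_rate, gamma = 0.1, 0.99

def q_update(state_action, reward, next_state_actions):
    # One tabular step:
    # Q_t+1(s, a) = (1 - lr) * Q_t(s, a) + lr * (R(s, a) + gamma * max_a' Q_t(s', a'))
    best_next = max((q_table.get(sa, 0.0) for sa in next_state_actions), default=0.0)
    old_value = q_table.get(state_action, 0.0)
    q_table[state_action] = ((1 - learning_rate) * old_value
                             + learning_rate * (reward + gamma * best_next))

q_update(('s0', bucket), reward=1.0, next_state_actions=[('s1', 0), ('s1', 1)])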