Example #1
    def __init__(self,
                 n_frames_per_action=4,
                 trace_type='replacing',
                 learning_rate=0.001,
                 discount=0.99,
                 lambda_v=0.5):
        super(Sarsa2Agent, self).__init__(name='Sarsa2', version='2')
        self.n_frames_per_action = n_frames_per_action

        self.epsilon = LinearInterpolationManager([(0, 1.0), (1e4, 0.005)])
        self.action_repeat_manager = RepeatManager(n_frames_per_action - 1)

        self.trace_type = trace_type
        self.learning_rate = learning_rate
        self.lambda_v = lambda_v
        self.discount = discount

        self.q_vals = None
        self.e_vals = None

        self.initialize_asr_and_counters()
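
LinearInterpolationManager and RepeatManager are project helpers that are not shown in this listing; the examples only rely on the small interface used above (breakpoint pairs in the constructor, next(), reset(), set()). Purely as a sketch of compatible stand-ins, under those interface assumptions and not the project's real implementations:

import numpy as np


class LinearInterpolationManager(object):
    """Minimal stand-in: linearly interpolates between (step, value) breakpoints."""

    def __init__(self, points):
        self.points = points
        self.step = 0

    def next(self):
        xs = [p[0] for p in self.points]
        ys = [p[1] for p in self.points]
        value = np.interp(self.step, xs, ys)
        self.step += 1
        return value

    def reset(self):
        self.step = 0


class RepeatManager(object):
    """Minimal stand-in: returns a stored value n times, then None."""

    def __init__(self, n_repeats):
        self.n_repeats = n_repeats
        self.value = None
        self.remaining = 0

    def next(self):
        if self.remaining > 0:
            self.remaining -= 1
            return self.value
        return None

    def set(self, value):
        self.value = value
        self.remaining = self.n_repeats

With the schedule used here, epsilon.next() would decay linearly from 1.0 at step 0 to 0.005 at step 10,000 and stay there afterwards.
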
Example #2
    def __init__(self, n_frames_per_action=4, 
                 trace_type='replacing', 
                 learning_rate=0.001,
                 discount=0.99, 
                 lambda_v=0.5,
                 record=False):
        super(SarsaAgent, self).__init__(name='Sarsa', version='1')
        self.n_frames_per_action = n_frames_per_action

        self.epsilon = LinearInterpolationManager([(0, 1.0), (1e4, 0.005)])
        self.action_repeat_manager = RepeatManager(n_frames_per_action - 1)
        
        self.trace_type = trace_type
        self.learning_rate = learning_rate
        self.lambda_v = lambda_v
        self.discount = discount

        self.a_ = 0
        self.s_ = 0
        self.r_ = 0

        self.q_vals = None
        self.e_vals = None

        self.n_goals = 0
        self.n_greedy = 0
        self.n_random = 0

        self.record = record
        if record:
            # e.g. 3 states, 5 actions
            # => q_vals.shape == (3, 5)
            #    e_vals.shape == (3, 5)
            #    each sarsa tuple has 5 entries: (s, a, r, s', a')
            self.mem = CircularList(100000) 

        self.n_rr = 0
        self.n_sa = 0

        self.n_episode = 0
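
CircularList is another helper that only appears here through append(), clear(), capacity(), len(), and iteration. A minimal ring-buffer stand-in consistent with that usage (an assumption, not the project's code) could be:

import collections


class CircularList(object):
    """Minimal stand-in: fixed-capacity buffer that drops the oldest items."""

    def __init__(self, capacity):
        self._capacity = capacity
        self._items = collections.deque(maxlen=capacity)

    def append(self, item):
        self._items.append(item)

    def clear(self):
        self._items.clear()

    def capacity(self):
        return self._capacity

    def __len__(self):
        return len(self._items)

    def __iter__(self):
        return iter(self._items)
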
Example #3
    def __init__(self, n_frames_per_action=4, 
                 trace_type='replacing', 
                 learning_rate=0.001,
                 discount=0.99, 
                 lambda_v=0.5):
        super(Sarsa2Agent, self).__init__(name='Sarsa2', version='2')
        self.n_frames_per_action = n_frames_per_action

        self.epsilon = LinearInterpolationManager([(0, 1.0), (1e4, 0.005)])
        self.action_repeat_manager = RepeatManager(n_frames_per_action - 1)
        
        self.trace_type = trace_type
        self.learning_rate = learning_rate
        self.lambda_v = lambda_v
        self.discount = discount

        self.q_vals = None
        self.e_vals = None

        self.initialize_asr_and_counters()
Example #4
class Sarsa2Agent(Agent):
    """
    Agent that uses SARSA(lambda).
    The raw input image is preprocessed, resulting in three states:
    + predicted ball position above the player
    + predicted ball position within the player pad
    + predicted ball position beneath the player
    """


    def __init__(self, n_frames_per_action=4, 
                 trace_type='replacing', 
                 learning_rate=0.001,
                 discount=0.99, 
                 lambda_v=0.5):
        super(Sarsa2Agent, self).__init__(name='Sarsa2', version='2')
        self.n_frames_per_action = n_frames_per_action

        self.epsilon = LinearInterpolationManager([(0, 1.0), (1e4, 0.005)])
        self.action_repeat_manager = RepeatManager(n_frames_per_action - 1)
        
        self.trace_type = trace_type
        self.learning_rate = learning_rate
        self.lambda_v = lambda_v
        self.discount = discount

        self.q_vals = None
        self.e_vals = None

        self.initialize_asr_and_counters()

    def initialize_asr_and_counters(self):
        self.a_ = 0
        self.s_ = 0
        self.r_ = 0

        self.n_goals = 0
        self.n_greedy = 0
        self.n_random = 0
        self.n_rr = 0
        self.n_sa = 0

        self.n_episode = 0

    def reset(self):
        self.q_vals[:] = 0.0
        self.e_vals[:] = 0.0
        self.epsilon.reset()
        self.initialize_asr_and_counters()

    def select_action(self):
        """
        Initialize Q(s, a) arbitrarily, for all s in S, a in A(s)
        Repeat (for each episode):
            E(s, a) = 0, for all s in S, a in A(s)
            Initialize S, A
            Repeat (for each step of episode):
              S = S'; A = A'
              Take action A, observe R, S'
              Choose A' from S' using policy derived from Q (e.g., e-greedy)
              update_q()
            until S is terminal
        """
        self.n_sa += 1

        sid = self.preprocessor.process()

        # assign previous s' to the current s
        s = self.s_
        # assign previous a' to the current a
        a = self.a_
        # get current state
        s_ = self.state_mapping[str(sid)]

        r = self.r_

        # select action:
        # - repeat previous action based on the n_frames_per_action param
        # - OR choose an action according to the e-greedy policy 
        a_ = self.action_repeat_manager.next()
        if a_ is None:
            a_ = self.e_greedy(s_)
            self.action_repeat_manager.set(a_)

        # Calculate update delta
        d = r + self.discount * self.q_vals[s_, a_] - self.q_vals[s, a]

        # Handle traces
        self.update_trace(s,a)

        # TODO: currently Q(s, a) is updated for all a, not a in A(s)!
        self.q_vals += self.learning_rate * d * self.e_vals
        self.e_vals *= (self.discount * self.lambda_v)

        # save current state, action for next iteration
        self.s_ = s_
        self.a_ = a_

        # save the state
        self.rlogger.write(self.n_episode, 
                           *[q for q in list(self.q_vals.flatten())
                             + list(self.e_vals.flatten())])

        return self.available_actions[a_]

    def set_results_dir(self, results_dir):
        super(Sarsa2Agent, self).set_results_dir(results_dir)

    def update_trace(self, s, a):
        # compare by value; 'is' on string literals only worked by interning accident
        if self.trace_type == 'accumulating':
            self.e_vals[s, a] += 1
        elif self.trace_type == 'replacing':
            self.e_vals[s, a] = 1
        elif self.trace_type == 'dutch':
            self.e_vals[s, a] *= (1 - self.learning_rate)
            self.e_vals[s, a] += 1

    def e_greedy(self, sid):
        """Returns action index
        """
        # decide on next action a'
        # E-greedy strategy
        if np.random.random() < self.epsilon.next():
            # explore: pick a random action and map it back to its index
            action = self.get_random_action()
            action = np.argmax(self.available_actions == action)
            self.n_random += 1
        else:
            # exploit: take the best action for the current state
            action = np.argmax(self.q_vals[sid, :])
            #print "greedy action {} from {}".format(action, self.q_vals[sid,:])
            self.n_greedy += 1
        return action

    def set_available_actions(self, actions):
        super(Sarsa2Agent, self).set_available_actions(actions)

        states = self.preprocessor.enumerate_states()
        state_n = len(states)

        # generate state to q_val index mapping
        self.state_mapping = dict([('{}'.format(v), i) 
                                    for i, v in enumerate(states)])
        print "Agent state_mapping:", self.state_mapping

        print 'state_n',state_n
        print 'actions',actions
        self.q_vals = np.zeros((state_n, len(actions)))
        self.e_vals = np.zeros((state_n, len(actions)))

        headers = 'episode'
        for q in range(len(self.q_vals.flatten())):
            headers += ',q{}'.format(q)
        for e in range(len(self.e_vals.flatten())):
            headers += ',e{}'.format(e)
        self.rlogger = CSVLogger(self.results_dir + '/q_e.csv', 
                                 headers, print_items=False)


    def set_raw_state_callbacks(self, state_functions):
        self.preprocessor = RelativeIntercept(state_functions)

    def receive_reward(self, reward):
        #print "receive_reward {}".format(self.n_rr)
        self.n_rr += 1
        self.r_ = reward
        if reward > 0:
            self.n_goals += 1

    def on_episode_start(self):
        self.n_goals = 0
        self.n_greedy = 0
        self.n_random = 0

    def on_episode_end(self):
        self.n_episode += 1
        #print "  q(s): {}".format(self.q_vals)
        #print "  e(s): {}".format(self.e_vals)
        #print "  goals: {}".format(self.n_goals)
        #print "  n_greedy: {}".format(self.n_greedy)
        #print "  n_random: {}".format(self.n_random)


    def get_settings(self):
        settings =  {
            "name": self.name,
            "version": self.version,
            "preprocessor": self.preprocessor.get_settings(),
            "n_frames_per_action": self.n_frames_per_action,
            "learning_rate": self.learning_rate,
            "discount_rate": self.discount, 
            "lambda": self.lambda_v,
        }

        settings.update(super(Sarsa2Agent, self).get_settings())
        
        return settings
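
To see the update performed in select_action outside the agent plumbing, here is a self-contained sketch of one tabular SARSA(lambda) step with replacing traces, mirroring the delta, trace, and Q-table operations above (the transition and numbers are made up for illustration):

import numpy as np

learning_rate, discount, lambda_v = 0.001, 0.99, 0.5

# 3 states x 2 actions, indexed like the agent's tables: q_vals[s, a]
q_vals = np.zeros((3, 2))
e_vals = np.zeros((3, 2))

# one observed transition (s, a, r, s', a') as table indices
s, a, r, s_, a_ = 0, 1, 1.0, 2, 0

# TD error: d = r + gamma * Q(s', a') - Q(s, a)
d = r + discount * q_vals[s_, a_] - q_vals[s, a]

# replacing trace for the visited pair
e_vals[s, a] = 1

# every entry moves in proportion to its trace, then the traces decay
q_vals += learning_rate * d * e_vals
e_vals *= discount * lambda_v

print(q_vals[s, a])  # 0.001: only entries with a non-zero trace change
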
Example #5
class Sarsa2Agent(Agent):
    """
    Agent that uses SARSA(lambda).
    The raw input image is preprocessed, resulting in three states:
    + predicted ball position above the player
    + predicted ball position within the player pad
    + predicted ball position beneath the player
    """
    def __init__(self,
                 n_frames_per_action=4,
                 trace_type='replacing',
                 learning_rate=0.001,
                 discount=0.99,
                 lambda_v=0.5):
        super(Sarsa2Agent, self).__init__(name='Sarsa2', version='2')
        self.n_frames_per_action = n_frames_per_action

        self.epsilon = LinearInterpolationManager([(0, 1.0), (1e4, 0.005)])
        self.action_repeat_manager = RepeatManager(n_frames_per_action - 1)

        self.trace_type = trace_type
        self.learning_rate = learning_rate
        self.lambda_v = lambda_v
        self.discount = discount

        self.q_vals = None
        self.e_vals = None

        self.initialize_asr_and_counters()

    def initialize_asr_and_counters(self):
        self.a_ = 0
        self.s_ = 0
        self.r_ = 0

        self.n_goals = 0
        self.n_greedy = 0
        self.n_random = 0
        self.n_rr = 0
        self.n_sa = 0

        self.n_episode = 0

    def reset(self):
        self.q_vals[:] = 0.0
        self.e_vals[:] = 0.0
        self.epsilon.reset()
        self.initialize_asr_and_counters()

    def select_action(self):
        """
        Initialize Q(s, a) arbitrarily, for all s in S, a in A(s)
        Repeat (for each episode):
            E(s, a) = 0, for all s in S, a in A(s)
            Initialize S, A
            Repeat (for each step of episode):
              S = S'; A = A'
              Take action A, observe R, S'
              Choose A' from S' using policy derived from Q (e.g., e-greedy)
              update_q()
            until S is terminal
        """
        self.n_sa += 1

        sid = self.preprocessor.process()

        # assign previous s' to the current s
        s = self.s_
        # assign previous a' to the current a
        a = self.a_
        # get current state
        s_ = self.state_mapping[str(sid)]

        r = self.r_

        # select action:
        # - repeat previous action based on the n_frames_per_action param
        # - OR choose an action according to the e-greedy policy
        a_ = self.action_repeat_manager.next()
        if a_ is None:
            a_ = self.e_greedy(s_)
            self.action_repeat_manager.set(a_)

        # Calculate update delta
        d = r + self.discount * self.q_vals[s_, a_] - self.q_vals[s, a]

        # Handle traces
        self.update_trace(s, a)

        # TODO: currently Q(s, a) is updated for all a, not a in A(s)!
        self.q_vals += self.learning_rate * d * self.e_vals
        self.e_vals *= (self.discount * self.lambda_v)

        # save current state, action for next iteration
        self.s_ = s_
        self.a_ = a_

        # save the state
        self.rlogger.write(
            self.n_episode, *[
                q for q in list(self.q_vals.flatten()) +
                list(self.e_vals.flatten())
            ])

        return self.available_actions[a_]

    def set_results_dir(self, results_dir):
        super(Sarsa2Agent, self).set_results_dir(results_dir)

    def update_trace(self, s, a):
        # compare by value; 'is' on string literals only worked by interning accident
        if self.trace_type == 'accumulating':
            self.e_vals[s, a] += 1
        elif self.trace_type == 'replacing':
            self.e_vals[s, a] = 1
        elif self.trace_type == 'dutch':
            self.e_vals[s, a] *= (1 - self.learning_rate)
            self.e_vals[s, a] += 1

    def e_greedy(self, sid):
        """Returns action index
        """
        # decide on next action a'
        # E-greedy strategy
        if np.random.random() < self.epsilon.next():
            # explore: pick a random action and map it back to its index
            action = self.get_random_action()
            action = np.argmax(self.available_actions == action)
            self.n_random += 1
        else:
            # exploit: take the best action for the current state
            action = np.argmax(self.q_vals[sid, :])
            #print "greedy action {} from {}".format(action, self.q_vals[sid,:])
            self.n_greedy += 1
        return action

    def set_available_actions(self, actions):
        super(Sarsa2Agent, self).set_available_actions(actions)

        states = self.preprocessor.enumerate_states()
        state_n = len(states)

        # generate state to q_val index mapping
        self.state_mapping = dict([('{}'.format(v), i)
                                   for i, v in enumerate(states)])
        print "Agent state_mapping:", self.state_mapping

        print 'state_n', state_n
        print 'actions', actions
        self.q_vals = np.zeros((state_n, len(actions)))
        self.e_vals = np.zeros((state_n, len(actions)))

        headers = 'episode'
        for q in range(len(self.q_vals.flatten())):
            headers += ',q{}'.format(q)
        for e in range(len(self.e_vals.flatten())):
            headers += ',e{}'.format(e)
        self.rlogger = CSVLogger(self.results_dir + '/q_e.csv',
                                 headers,
                                 print_items=False)

    def set_raw_state_callbacks(self, state_functions):
        self.preprocessor = RelativeIntercept(state_functions)

    def receive_reward(self, reward):
        #print "receive_reward {}".format(self.n_rr)
        self.n_rr += 1
        self.r_ = reward
        if reward > 0:
            self.n_goals += 1

    def on_episode_start(self):
        self.n_goals = 0
        self.n_greedy = 0
        self.n_random = 0

    def on_episode_end(self):
        self.n_episode += 1
        #print "  q(s): {}".format(self.q_vals)
        #print "  e(s): {}".format(self.e_vals)
        #print "  goals: {}".format(self.n_goals)
        #print "  n_greedy: {}".format(self.n_greedy)
        #print "  n_random: {}".format(self.n_random)

    def get_settings(self):
        settings = {
            "name": self.name,
            "version": self.version,
            "preprocessor": self.preprocessor.get_settings(),
            "n_frames_per_action": self.n_frames_per_action,
            "learning_rate": self.learning_rate,
            "discount_rate": self.discount,
            "lambda": self.lambda_v,
        }

        settings.update(super(Sarsa2Agent, self).get_settings())

        return settings
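
The update_trace method above implements three eligibility-trace variants. As a toy comparison (not part of the original code) of how they behave when the same state-action pair is visited twice in a row:

import numpy as np

learning_rate = 0.001

def visit(e, s, a, trace_type):
    # same three branches as update_trace above
    if trace_type == 'accumulating':
        e[s, a] += 1
    elif trace_type == 'replacing':
        e[s, a] = 1
    elif trace_type == 'dutch':
        e[s, a] *= (1 - learning_rate)
        e[s, a] += 1

for trace_type in ('accumulating', 'replacing', 'dutch'):
    e = np.zeros((3, 2))
    visit(e, 0, 0, trace_type)
    visit(e, 0, 0, trace_type)
    print("{}: {}".format(trace_type, e[0, 0]))  # 2.0, 1.0, 1.999
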
Example #6
class SLAgent(Agent):
    """Agent using keras NN
    """

    def __init__(self, n_frames_per_action=4):
        super(SLAgent, self).__init__(name="SL", version="1")
        self.experience = CircularList(1000)
        self.epsilon = LinearInterpolationManager([(0, 1.0), (1e4, 0.1)])
        self.action_repeat_manager = RepeatManager(n_frames_per_action - 1)

    def select_action(self):
        # Repeat last chosen action?
        action = self.action_repeat_manager.next()
        if action is not None:
            return action

        state = self.preprocessor.process()
        try:
            s = np.array(state).reshape(len(state), 1)
        except TypeError:
            # scalar state has no len(); store it as a 1x1 array
            s = np.array(state).reshape(1, 1)

        if self._sars[2]:
            self._sars[3] = s
            self.flush_experience()

        # Consider postponing the first training until we have 32 samples
        if len(self.experience) > 0:
            self.nn.train(self.experience)

        if np.random.random() < self.epsilon.next():
            action = self.get_random_action()
        else:
            action_index = self.nn.predict(s)
            action = self.available_actions[action_index]

        self.action_repeat_manager.set(action)

        self._sars[0] = s
        self._sars[1] = self.available_actions.index(action)

        return action

    def set_available_actions(self, actions):
        super(SLAgent, self).set_available_actions(actions)
        # possible state values
        state_n = len(self.preprocessor.enumerate_states())

        self.nn = MLP(config="simple", input_ranges=[[0, state_n]], n_outputs=len(actions), batch_size=4)

    def set_raw_state_callbacks(self, state_functions):
        self.preprocessor = StateIndex(RelativeBall(state_functions, trinary=True))

    def receive_reward(self, reward):
        self._sars[2] = reward

    def on_episode_start(self):
        self._reset_sars()

    def on_episode_end(self):
        self._sars[3] = self._sars[0]
        self._sars[4] = 0
        self.flush_experience()

    def flush_experience(self):
        self.experience.append(tuple(self._sars))
        self._reset_sars()

    def _reset_sars(self):
        # state, action, reward, newstate, newstate_not_terminal
        self._sars = [None, None, None, None, 1]

    def get_settings(self):
        settings = {
            "name": self.name,
            "version": self.version,
            "experience_replay": self.experience.capacity(),
            "preprocessor": self.preprocessor.get_settings(),
            "epsilon": self.epsilon.get_settings(),
            "nn": self.nn.get_settings(),
        }

        settings.update(super(SLAgent, self).get_settings())

        return settings
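
The try/except in select_action handles the fact that the preprocessor may return either a list-like state or a single scalar. A tiny standalone sketch of that reshaping (illustration only):

import numpy as np

def to_column(state):
    # same reshaping as in SLAgent.select_action: sequences become a column
    # vector, scalars become a 1x1 array
    try:
        return np.array(state).reshape(len(state), 1)
    except TypeError:  # scalar state has no len()
        return np.array(state).reshape(1, 1)

print(to_column([1, 0, 2]).shape)  # (3, 1)
print(to_column(2).shape)          # (1, 1)
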
Example #7
    def __init__(self, n_frames_per_action=4):
        super(SLAgent, self).__init__(name="SL", version="1")
        self.experience = CircularList(1000)
        self.epsilon = LinearInterpolationManager([(0, 1.0), (1e4, 0.1)])
        self.action_repeat_manager = RepeatManager(n_frames_per_action - 1)
Example #8
class SLAgent(Agent):
    """Agent using keras NN
    """
    def __init__(self, n_frames_per_action=4):
        super(SLAgent, self).__init__(name='SL', version='1')
        self.experience = CircularList(1000)
        self.epsilon = LinearInterpolationManager([(0, 1.0), (1e4, 0.1)])
        self.action_repeat_manager = RepeatManager(n_frames_per_action - 1)

    def select_action(self):
        # Repeat last chosen action?
        action = self.action_repeat_manager.next()
        if action is not None:
            return action

        state = self.preprocessor.process()
        try:
            s = np.array(state).reshape(len(state), 1)
        except TypeError:
            # scalar state has no len(); store it as a 1x1 array
            s = np.array(state).reshape(1, 1)

        if self._sars[2]:
            self._sars[3] = s
            self.flush_experience()

        # Consider postponing the first training until we have 32 samples
        if len(self.experience) > 0:
            self.nn.train(self.experience)

        if np.random.random() < self.epsilon.next():
            action = self.get_random_action()
        else:
            action_index = self.nn.predict(s)
            action = self.available_actions[action_index]

        self.action_repeat_manager.set(action)

        self._sars[0] = s
        self._sars[1] = self.available_actions.index(action)

        return action

    def set_available_actions(self, actions):
        super(SLAgent, self).set_available_actions(actions)
        # possible state values
        state_n = len(self.preprocessor.enumerate_states())

        self.nn = MLP(config='simple',
                      input_ranges=[[0, state_n]],
                      n_outputs=len(actions),
                      batch_size=4)

    def set_raw_state_callbacks(self, state_functions):
        self.preprocessor = StateIndex(
            RelativeBall(state_functions, trinary=True))

    def receive_reward(self, reward):
        self._sars[2] = reward

    def on_episode_start(self):
        self._reset_sars()

    def on_episode_end(self):
        self._sars[3] = self._sars[0]
        self._sars[4] = 0
        self.flush_experience()

    def flush_experience(self):
        self.experience.append(tuple(self._sars))
        self._reset_sars()

    def _reset_sars(self):
        # state, action, reward, newstate, newstate_not_terminal
        self._sars = [None, None, None, None, 1]

    def get_settings(self):
        settings = {
            "name": self.name,
            "version": self.version,
            "experience_replay": self.experience.capacity(),
            "preprocessor": self.preprocessor.get_settings(),
            "epsilon": self.epsilon.get_settings(),
            "nn": self.nn.get_settings(),
        }

        settings.update(super(SLAgent, self).get_settings())

        return settings
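
The MLP wrapper is defined elsewhere in the project; this example only relies on the constructor keywords shown plus train(experience), predict(s), and get_settings(). As a heavily simplified, table-based stand-in under those interface assumptions (useful for running the agent without the real Keras model, not the project's actual class):

import numpy as np


class MLP(object):
    """Table-based stand-in exposing the interface SLAgent relies on."""

    def __init__(self, config, input_ranges, n_outputs, batch_size):
        self.config = config
        self.input_ranges = input_ranges
        self.n_outputs = n_outputs
        self.batch_size = batch_size
        n_states = int(input_ranges[0][1] - input_ranges[0][0])
        self.table = np.zeros((n_states, n_outputs))

    def train(self, experience):
        # crude reward averaging per (state index, action index) pair
        for s, a, r, s_next, not_terminal in experience:
            if r is None or a is None:
                continue
            idx = int(np.asarray(s).flatten()[0])
            self.table[idx, a] += 0.1 * (r - self.table[idx, a])

    def predict(self, s):
        idx = int(np.asarray(s).flatten()[0])
        return int(np.argmax(self.table[idx]))

    def get_settings(self):
        return {"config": self.config,
                "input_ranges": self.input_ranges,
                "n_outputs": self.n_outputs,
                "batch_size": self.batch_size}
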
Example #9
    def __init__(self, n_frames_per_action=4):
        super(SLAgent, self).__init__(name='SL', version='1')
        self.experience = CircularList(1000)
        self.epsilon = LinearInterpolationManager([(0, 1.0), (1e4, 0.1)])
        self.action_repeat_manager = RepeatManager(n_frames_per_action - 1)
Example #10
class SarsaAgent(Agent):
    """
    Agent that uses SARSA(lambda).
    The input RGB image is preprocessed, resulting in states:
    - (x, y) ball
    - y player
    - y opponent
    """


    def __init__(self, n_frames_per_action=4, 
                 trace_type='replacing', 
                 learning_rate=0.001,
                 discount=0.99, 
                 lambda_v=0.5,
                 record=False):
        super(SarsaAgent, self).__init__(name='Sarsa', version='1')
        self.n_frames_per_action = n_frames_per_action

        self.epsilon = LinearInterpolationManager([(0, 1.0), (1e4, 0.005)])
        self.action_repeat_manager = RepeatManager(n_frames_per_action - 1)
        
        self.trace_type = trace_type
        self.learning_rate = learning_rate
        self.lambda_v = lambda_v
        self.discount = discount

        self.a_ = 0
        self.s_ = 0
        self.r_ = 0

        self.q_vals = None
        self.e_vals = None

        self.n_goals = 0
        self.n_greedy = 0
        self.n_random = 0

        self.record = record
        if record:
            # e.g. 3 states, 5 actions
            # => q_vals.shape == (3, 5)
            #    e_vals.shape == (3, 5)
            #    each sarsa tuple has 5 entries: (s, a, r, s', a')
            self.mem = CircularList(100000) 

        self.n_rr = 0
        self.n_sa = 0

        self.n_episode = 0


    def reset(self):
        pass

    def select_action(self):
        """
        Initialize Q(s, a) arbitrarily, for all s in S, a in A(s)
        Repeat (for each episode):
            E(s, a) = 0, for all s in S, a in A(s)
            Initialize S, A
            Repeat (for each step of episode):
              S = S'; A = A'
              Take action A, observe R, S'
              Choose A' from S' using policy derived from Q (e.g., e-greedy)
              update_q()
            until S is terminal
        """
        #print "select_action {}".format(self.n_sa)
        self.n_sa += 1

        sid = self.preprocessor.process()
        if sid == 0:
            return 0

        # assign previous s' to the current s
        s = self.s_
        # assign previous a' to the current a
        a = self.a_
        # get current state
        s_ = self.state_mapping[str(sid)]

        r = self.r_

        # select action:
        # - repeat previous action based on the n_frames_per_action param
        # - OR choose an action according to the e-greedy policy 
        a_ = self.action_repeat_manager.next()
        if a_ is None:
            a_ = self.e_greedy(s_)
            self.action_repeat_manager.set(a_)

        #print "running SARSA with {}".format([s, a, r, s_, a_])

        """
              d = R + gamma*Q(S', A') - Q(S, A)
              E(S,A) = E(S,A) + 1           (accumulating traces)
           or E(S,A) = (1 - a) * E(S,A) + 1 (dutch traces)
           or E(S;A) = 1                    (replacing traces)
              For all s in S; a in A(s):
                Q(s,a) = Q(s,a) + E(s,a)   
                E(s,a) = gamma * lambda * E(s,a)
        """
        d = r + self.discount * self.q_vals[s_, a_] - self.q_vals[s, a]
        # compare by value; 'is' on string literals only worked by interning accident
        if self.trace_type == 'accumulating':
            self.e_vals[s, a] += 1
        elif self.trace_type == 'replacing':
            self.e_vals[s, a] = 1
        elif self.trace_type == 'dutch':
            self.e_vals[s, a] *= (1 - self.learning_rate)
            self.e_vals[s, a] += 1

        # TODO: currently Q(s, a) is updated for all a, not a in A(s)!
        self.q_vals += self.learning_rate * d * self.e_vals
        self.e_vals *= (self.discount * self.lambda_v)

        #if r != 0:
        #    print "lr: {} d: {}".format(self.learning_rate, d)
        #    print "d q_vals\n{}".format(self.q_vals - p_q_vals)


        # save current state, action for next iteration
        self.s_ = s_
        self.a_ = a_

        # save the state
        self.rlogger.write(self.n_episode, *[q for q in list(self.q_vals.flatten()) + list(self.e_vals.flatten())])

        if self.record: 
            self.mem.append({'q_vals': np.copy(self.q_vals), 
                             'sarsa': (s, a, r, s_, a_)})

        return self.available_actions[a_]

    def set_results_dir(self, results_dir):
        super(SarsaAgent, self).set_results_dir(results_dir)

    def e_greedy(self, sid):
        """Returns action index
        """
        # decide on next action a'
        # E-greedy strategy
        if np.random.random() < self.epsilon.next():
            # explore: pick a random action and map it back to its index
            action = self.get_random_action()
            action = np.argmax(self.available_actions == action)
            self.n_random += 1
        else:
            # exploit: take the best action for the current state
            action = np.argmax(self.q_vals[sid, :])
            #print "greedy action {} from {}".format(action, self.q_vals[sid,:])
            self.n_greedy += 1
        return action

    def set_available_actions(self, actions):
        # remove NO-OP from available actions
        actions = np.delete(actions, 0)

        super(SarsaAgent, self).set_available_actions(actions)

        states = self.preprocessor.enumerate_states()
        state_n = len(states)

        # generate state to q_val index mapping
        self.state_mapping = dict([('{}'.format(v), i) for i, v in enumerate(states)])
        print self.state_mapping

        print 'state_n',state_n
        print 'actions',actions
        self.q_vals = np.zeros((state_n, len(actions)))
        self.e_vals = np.zeros((state_n, len(actions)))

        headers = 'episode'
        for q in range(len(self.q_vals.flatten())):
            headers += ',q{}'.format(q)
        for e in range(len(self.e_vals.flatten())):
            headers += ',e{}'.format(e)
        self.rlogger = CSVLogger(self.results_dir + '/q_e.csv', headers, print_items=False)


    def set_raw_state_callbacks(self, state_functions):
        self.preprocessor = RelativeIntercept(state_functions, mode='binary')

    def receive_reward(self, reward):
        #print "receive_reward {}".format(self.n_rr)
        self.n_rr += 1
        self.r_ = reward
        if reward > 0:
            self.n_goals += 1

    def on_episode_start(self):
        self.n_goals = 0
        self.n_greedy = 0
        self.n_random = 0

    def on_episode_end(self):
        self.n_episode += 1
        #print "  q(s): {}".format(self.q_vals)
        #print "  e(s): {}".format(self.e_vals)
        #print "  goals: {}".format(self.n_goals)
        #print "  n_greedy: {}".format(self.n_greedy)
        #print "  n_random: {}".format(self.n_random)

        if self.record:
            a_s = [(e['sarsa'][4], e['sarsa'][3]) for e in self.mem]
            a_counts = [0] * self.q_vals.shape[1]  # one counter per action
            s_counts = [0] * self.q_vals.shape[0]  # one counter per state
            for a, s in a_s:
                a_counts[a] += 1
                s_counts[s] += 1
            print "  actions: {}".format(a_counts)
            print "  states: {}".format(s_counts)

            self.mem.clear()

    def get_learning_dump(self):
        return self.mem

    def get_settings(self):
        settings =  {
            "name": self.name,
            "version": self.version,
            "preprocessor": self.preprocessor.get_settings(),
            "n_frames_per_action": self.n_frames_per_action,
            "learning_rate": self.learning_rate,
            "discount_rate": self.discount, 
            "lambda": self.lambda_v,
        }

        settings.update(super(SarsaAgent, self).get_settings())
        
        return settings
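
Each rlogger.write call above appends one row holding the episode number followed by the flattened q_vals and e_vals tables. Assuming CSVLogger writes a plain comma-separated file with the generated header (an assumption about its output format), the tables can be recovered per logged step like this:

import numpy as np

def read_q_e(path, state_n, n_actions):
    # hypothetical reader for the q_e.csv written above
    n_q = state_n * n_actions
    data = np.atleast_2d(np.genfromtxt(path, delimiter=',', skip_header=1))
    for row in data:
        episode = int(row[0])
        q_vals = row[1:1 + n_q].reshape(state_n, n_actions)
        e_vals = row[1 + n_q:1 + 2 * n_q].reshape(state_n, n_actions)
        yield episode, q_vals, e_vals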