Example No. 1
def q_learning(env, estimator, num_episodes, discount_factor=1.0, epsilon=0.1, epsilon_decay=1.0):
    """
    Q-Learning algorithm for off-policy TD control using Function Approximation.
    Finds the optimal greedy policy while following an epsilon-greedy policy.
    
    Args:
        env: OpenAI environment.
        estimator: Action-Value function estimator
        num_episodes: Number of episodes to run for.
        discount_factor: Gamma discount factor.
        epsilon: Chance to sample a random action. Float between 0 and 1.
        epsilon_decay: Each episode, epsilon is decayed by this factor.
    
    Returns:
        An EpisodeStats object with two numpy arrays for episode_lengths and episode_rewards.
    """

    # Keeps track of useful statistics
    stats = plots.EpisodeStats(
        episode_lengths=np.zeros(num_episodes),
        episode_rewards=np.zeros(num_episodes))    
    
    for i_episode in range(num_episodes):
        
        # The policy we're following
        policy = make_epsilon_greedy_policy(
            estimator, epsilon * epsilon_decay**i_episode, env.action_space.n)
        
        # Print out which episode we're on, useful for debugging.
        # Also print reward for last episode
        last_reward = stats.episode_rewards[i_episode - 1]
        sys.stdout.flush()
        
        state = env.reset()
        
        for t in itertools.count():
            # sample action
            action_probs = policy(state)
            action = np.random.choice(np.arange(len(action_probs)),
                                      p=action_probs)
            
            next_state, reward, is_done, _ = env.step(action)
            
            # Update statistics
            stats.episode_rewards[i_episode] += reward
            stats.episode_lengths[i_episode] = t
            
            q_values_next = estimator.predict(next_state)
            td_target = reward + discount_factor * np.max(q_values_next)
            estimator.update(state, action, td_target)
            print("\rStep {} @ Episode {}/{} ({})".format(t, i_episode + 1, num_episodes, last_reward), end="")
            if is_done:
                break
            
            state = next_state
    return stats
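
The helper make_epsilon_greedy_policy is not shown in this example. A minimal sketch of what it is assumed to do, given an estimator whose predict(state) returns one Q value per action (the tabular create_epsilon_greedy_policy used by the later examples is analogous, reading Q[state] instead of calling estimator.predict(state)):

import numpy as np

def make_epsilon_greedy_policy(estimator, epsilon, nA):
    """Return a function mapping a state to epsilon-greedy action probabilities."""
    def policy_fn(state):
        # uniform exploration mass of epsilon / nA on every action
        action_probs = np.ones(nA, dtype=float) * epsilon / nA
        # remaining (1 - epsilon) probability mass goes to the greedy action
        best_action = np.argmax(estimator.predict(state))
        action_probs[best_action] += 1.0 - epsilon
        return action_probs
    return policy_fn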
Example No. 2
def n_step_sarsa(env,
                 num_episodes,
                 n=5,
                 discount_factor=1.0,
                 alpha=0.5,
                 epsilon=0.1):
    """
    n-step SARSA algorithm: On-policy TD control. Finds the optimal epsilon-greedy policy.
    The algorithm looks forward n steps and then bootstraps from there to update the Q values.
    
    Args:
        env: OpenAI environment.
        num_episodes: Number of episodes to run for.
        n: future time steps to look ahead and calculate return for.
        discount_factor: Gamma discount factor.
        alpha: TD learning rate.
        epsilon: Chance to sample a random action. Float between 0 and 1.
    
    Returns:
        A tuple (Q, stats).
        Q is the optimal action-value function, a dictionary mapping state -> action values.
        stats is an EpisodeStats object with two numpy arrays for episode_lengths and episode_rewards.
    NOTE: some parts taken from https://github.com/Breakend/MultiStepBootstrappingInRL/blob/master/n_step_sarsa.py
    """

    # The final action-value function.
    # A nested dictionary that maps state -> (action -> action-value).
    Q = defaultdict(lambda: np.zeros(env.action_space.n))

    # Keeps track of useful statistics
    stats = plots.EpisodeStats(episode_lengths=np.zeros(num_episodes),
                               episode_rewards=np.zeros(num_episodes))

    # The policy we're following
    policy = create_epsilon_greedy_policy(Q, epsilon, env.action_space.n)

    for i_episode in range(num_episodes):
        # Print out which episode we're on, useful for debugging.
        if (i_episode + 1) % 10 == 0:
            print("\rEpisode {}/{}.".format(i_episode + 1, num_episodes),
                  end="")
            sys.stdout.flush()

        # initializations
        T = sys.maxsize
        tau = 0
        t = -1
        stored_actions = {}
        stored_rewards = {}
        stored_states = {}

        # initialize first state
        state = env.reset()
        action_probs = policy(state)
        action = np.random.choice(env.action_space.n, p=action_probs)

        stored_actions[0] = action
        stored_states[0] = state

        while tau < (T - 1):
            t += 1
            if t < T:
                state, reward, done, _ = env.step(action)

                stored_rewards[(t + 1) % (n + 1)] = reward
                stored_states[(t + 1) % (n + 1)] = state

                # Update statistics
                stats.episode_rewards[i_episode] += reward
                stats.episode_lengths[i_episode] = t

                if done:
                    T = t + 1
                else:
                    next_action_probs = policy(state)
                    action = np.random.choice(env.action_space.n,
                                              p=next_action_probs)
                    stored_actions[(t + 1) % (n + 1)] = action
            tau = t - n + 1

            if tau >= 0:
                # calculate G(tau:tau+n)
                G = np.sum([
                    discount_factor**(i - tau - 1) * stored_rewards[i % (n + 1)]
                    for i in range(tau + 1, min(tau + n, T) + 1)
                ])

                if tau + n < T:
                    G += discount_factor**n * Q[stored_states[(tau + n) % (n + 1)]][stored_actions[(tau + n) % (n + 1)]]

                tau_s, tau_a = stored_states[tau % (n + 1)], stored_actions[tau % (n + 1)]

                # update Q value with n step return
                Q[tau_s][tau_a] += alpha * (G - Q[tau_s][tau_a])

    return Q, stats
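
For intuition, G above is the n-step return: the next n rewards discounted by gamma, plus gamma**n times the bootstrapped Q value at time tau + n whenever that step falls before the end of the episode. A small standalone check of the same sum, using a plain list and made-up numbers instead of the modular buffers:

# Toy check of the n-step return G(tau:tau+n) with hypothetical values.
rewards = [1.0, 0.0, 2.0]          # r_{tau+1}, r_{tau+2}, r_{tau+3}
gamma, n, bootstrap = 0.9, 3, 5.0  # bootstrap stands in for Q[S_{tau+n}][A_{tau+n}]
G = sum(gamma**i * r for i, r in enumerate(rewards)) + gamma**n * bootstrap
print(G)  # 1.0 + 0.9 * 0.0 + 0.81 * 2.0 + 0.729 * 5.0 = 6.265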
Example No. 3
def sarsa_lambd(env, num_episodes, discount_factor=1.0, alpha=0.5, epsilon=0.1, lambd=0.9):
    """
    SARSA(lambda) algorithm: On-policy TD control with eligibility traces. Finds the optimal epsilon-greedy policy.
    
    Args:
        env: OpenAI environment.
        num_episodes: Number of episodes to run for.
        discount_factor: Gamma discount factor.
        alpha: TD learning rate.
        epsilon: Chance to sample a random action. Float between 0 and 1.
        lambd: Trace-decay parameter (lambda) for the eligibility traces.
    
    Returns:
        A tuple (Q, stats).
        Q is the optimal action-value function, a dictionary mapping state -> action values.
        stats is an EpisodeStats object with two numpy arrays for episode_lengths and episode_rewards.
    """
    nA = env.action_space.n
    
    # The final action-value function.
    # A nested dictionary that maps state -> (action -> action-value).
    Q = defaultdict(lambda: np.zeros(env.action_space.n))
    
    # Keeps track of useful statistics
    stats = plots.EpisodeStats(
        episode_lengths=np.zeros(num_episodes),
        episode_rewards=np.zeros(num_episodes))

    # The policy we're following
    policy = create_epsilon_greedy_policy(Q, epsilon, env.action_space.n)
    

    for i_episode in range(num_episodes):
        # Print out which episode we're on, useful for debugging.
        if (i_episode + 1) % 100 == 0:
            print("\rEpisode {}/{}.".format(i_episode + 1, num_episodes), end="")
            sys.stdout.flush()
            
        E = defaultdict(lambda: np.zeros(env.action_space.n))
        state = env.reset()
        action_probs = policy(state)
        action = np.random.choice(env.action_space.n, p=action_probs)
        for t in itertools.count():
            # take action and observe the environment's effects
            next_state, reward, done, _ = env.step(action)
            
            next_action_probs = policy(next_state)
            
            next_action = np.random.choice(env.action_space.n, p=next_action_probs)
            
            # Update statistics
            stats.episode_rewards[i_episode] += reward
            stats.episode_lengths[i_episode] = t
            
            td_error = reward + (discount_factor * Q[next_state][next_action]) - Q[state][action]
            E[state][action] += 1
            
            for s, _ in Q.items():
                for a_ in range(nA):
                    Q[s][a_] += alpha * td_error * E[s][a_]
                    E[s][a_] *= discount_factor * lambd
            
            if done:
                break
            
            state = next_state
            action = next_action
    
    return Q, stats
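
The inner double loop implements accumulating eligibility traces: the trace of the visited (state, action) pair is incremented by 1, every Q value is then nudged in proportion to its trace, and all traces decay by gamma * lambda. A tiny standalone illustration of that accumulate-and-decay schedule for a single pair that is visited on every step, with made-up numbers:

gamma, lambd = 1.0, 0.9
trace = 0.0
for step in range(4):
    trace += 1.0              # the pair is visited again
    print(round(trace, 3))    # 1.0, 1.9, 2.71, 3.439
    trace *= gamma * lambd    # decay applied after every TD update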
Example No. 4
def n_step_expected_sarsa(env, num_episodes, n=10, gamma=0.9, alpha=0.1, epsilon=0.3):
    """
    n-step Expected SARSA algorithm: Off-policy TD control. Finds the optimal target policy while following a more exploratory behavior policy.
    
    Args:
        env: OpenAI environment.
        num_episodes: Number of episodes to run for.
        n: future time steps to look ahead and calculate return for.
        gamma: Gamma discount factor.
        alpha: TD learning rate.
        epsilon: Chance to sample a random action. Float between 0 and 1.
    
    Returns:
        A tuple (Q, stats).
        Q is the optimal action-value function, a dictionary mapping state -> action values.
        stats is an EpisodeStats object with two numpy arrays for episode_lengths and episode_rewards.
    NOTE: some parts taken from https://github.com/Breakend/MultiStepBootstrappingInRL/blob/master/n_step_sarsa.py
    """
    # initializations
    # number of actions
    nA = env.action_space.n
    
    # create Q dict. A nested dict that maps state-> action values
    Q = defaultdict(lambda: np.zeros(nA))
    
    # The behavior policy we follow to generate experience;
    # more exploratory (less greedy) than the target policy.
    behavior_policy = create_behavior_policy(Q, nA)

    # The target policy we are learning; less exploratory
    # (i.e. greedier) than the behavior policy.
    target_policy = create_target_policy(Q, nA)
    
    # Keeps track of useful statistics
    stats = plots.EpisodeStats(
        episode_lengths=np.zeros(num_episodes),
        episode_rewards=np.zeros(num_episodes))
    
    max_reward = 0
    total_reward = 0
    rewards_per_episode = []
    q_variance = []
    
    for i_episode in range(num_episodes):
        # Print out which episode we're on, useful for debugging.
        if (i_episode + 1) % 10 == 0:
            print("\rEpisode {}/{}.".format(i_episode + 1, num_episodes), end="")
            sys.stdout.flush()
        
        T = sys.maxsize
        tau = 0
        t = -1
        
        stored_actions = {}
        stored_rewards = {}
        stored_states = {}
        
        # reset env to get initial state
        state = env.reset()
    
        # get action probs from behavior policy
        action_probs = behavior_policy(state)
        action = np.random.choice(np.arange(nA), p=action_probs)
        
        stored_actions[0] = action
        stored_states[0] = state
        
        while tau < (T - 1):
            t += 1
            if t < T:
                state, reward, done, _ = env.step(action)
                
                # Update statistics
                stats.episode_rewards[i_episode] += reward
                stats.episode_lengths[i_episode] = t
                
                stored_rewards[(t+1) % (n+1)] = reward
                stored_states[(t+1) % (n+1)] = state
                
                if done:
                    T = t + 1
                else:
                    action_probs = behavior_policy(state)
                    action = np.random.choice(np.arange(nA), p=action_probs)
                    stored_actions[(t+1) % (n+1)] = action
            tau = t - n + 1
            if tau >= 0:
                # calculate rho
                rho = np.prod([
                    target_policy(stored_states[i % (n+1)])[stored_actions[i % (n+1)]] /
                    behavior_policy(stored_states[i % (n+1)])[stored_actions[i % (n+1)]]
                    for i in range(tau+1, min(tau+n-1, T-1)+1)
                ])
                
                # calculate return
                G = np.sum([(gamma**(i-tau-1))*stored_rewards[i%(n+1)] for i in range(tau+1, min(tau+n, T)+1)])
                
                
                if tau+n < T:
                    expected_sarsa_update = np.sum(
                        [target_policy(stored_states[(tau+n) % (n+1)])[a] * Q[stored_states[(tau+n) % (n+1)]][a] for a in range(nA)]
                    )
                    G += (gamma**n) * expected_sarsa_update
                    
                s_tau, a_tau = stored_states[tau % (n+1)], stored_actions[tau % (n+1)]
                
                td_error = G - Q[s_tau][a_tau]
                Q[s_tau][a_tau] += alpha * rho * td_error
    return Q, stats
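
create_behavior_policy and create_target_policy are not included in the example. A minimal sketch of what they are assumed to return, namely functions mapping a state to a vector of action probabilities, with an epsilon-greedy behavior policy and a greedy target policy (the epsilon value here is an assumption):

import numpy as np

def create_behavior_policy(Q, nA, epsilon=0.3):
    """Epsilon-greedy probabilities used to generate behaviour."""
    def policy_fn(state):
        probs = np.ones(nA, dtype=float) * epsilon / nA
        probs[np.argmax(Q[state])] += 1.0 - epsilon
        return probs
    return policy_fn

def create_target_policy(Q, nA):
    """Greedy probabilities: all mass on the argmax of Q[state]."""
    def policy_fn(state):
        probs = np.zeros(nA)
        probs[np.argmax(Q[state])] = 1.0
        return probs
    return policy_fn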
Example No. 5
def q_learning(env, num_episodes, discount_factor=1.0, alpha=0.5, epsilon=0.1):
    """
    Q-Learning algorithm: Off-policy TD control. Finds the optimal greedy policy
    while following an epsilon-greedy policy
    
    Args:
        env: OpenAI environment.
        num_episodes: Number of episodes to run for.
        discount_factor: Gamma discount factor.
        alpha: TD learning rate.
        epsilon: Chance to sample a random action. Float between 0 and 1.
    
    Returns:
        A tuple (Q, stats).
        Q is the optimal action-value function, a dictionary mapping state -> action values.
        stats is an EpisodeStats object with two numpy arrays for episode_lengths and episode_rewards.
    """

    # The final action-value function.
    # A nested dictionary that maps state -> (action -> action-value).
    Q = defaultdict(lambda: np.zeros(env.action_space.n))

    # Keeps track of useful statistics
    stats = plots.EpisodeStats(episode_lengths=np.zeros(num_episodes),
                               episode_rewards=np.zeros(num_episodes))

    # The policy we're following
    policy = create_epsilon_greedy_policy(Q, epsilon, env.action_space.n)

    for i_episode in range(num_episodes):
        # Print out which episode we're on, useful for debugging.
        if (i_episode + 1) % 100 == 0:
            print("\rEpisode {}/{}.".format(i_episode + 1, num_episodes),
                  end="")
            sys.stdout.flush()

        state = env.reset()

        for t in itertools.count():
            # sample action from behavior policy
            action_probs = policy(state)
            action = np.random.choice(env.action_space.n, p=action_probs)

            # take action and observe environment's effects
            next_state, reward, done, _ = env.step(action)

            # Update statistics
            stats.episode_rewards[i_episode] += reward
            stats.episode_lengths[i_episode] = t

            # sample next action from target policy
            next_action = np.argmax(Q[next_state])

            td_target = reward + discount_factor * Q[next_state][next_action]

            # update Q value
            Q[state][action] += alpha * (td_target - Q[state][action])

            if done:
                break

            state = next_state

    return Q, stats
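
A hypothetical usage sketch, assuming a discrete Gym environment with the pre-0.26 reset/step API that these snippets rely on (CliffWalking-v0 is just one possible choice) and the helper functions above in scope:

import gym
import numpy as np

env = gym.make("CliffWalking-v0")
Q, stats = q_learning(env, num_episodes=500, alpha=0.5, epsilon=0.1)

# greedy action in the start state under the learned action-value function
state = env.reset()
print(np.argmax(Q[state]))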
Example No. 6
 def train(self):
     stats = plots.EpisodeStats(
     episode_lengths=np.zeros(self.num_episodes),
     episode_rewards=np.zeros(self.num_episodes))
 
     Transition = namedtuple("Transition", ["state", "action", "reward", "next_state", "done"])
     for i_episode in range(self.num_episodes):
         state = env.reset()
         trajectory = list()
         for t in itertools.count():
             # get action prediction
             action_probs = self.policy_estimator.predict(state).detach().numpy()
             
             # get action
             action = np.random.choice(np.arange(len(action_probs)), p=action_probs)
             
             # take step
             next_state, reward, done, _ = env.step(action)
             
             trajectory.append(Transition(state=state,
                                          action=action,
                                          reward=reward,
                                          next_state=next_state,
                                          done=done))
             
             stats.episode_lengths[i_episode] = t
             stats.episode_rewards[i_episode] += reward
             
             # Print out which step we're on, useful for debugging.
             print("\rStep {} @ Episode {}/{} ({})".format(
                 t, i_episode + 1, self.num_episodes, stats.episode_rewards[i_episode - 1]), end="")
             sys.stdout.flush()
             
             if done:
                 break
             
             state = next_state
             
         for t, transition in enumerate(trajectory):
             # get total reward
             total_return = sum(self.gamma**i * tr.reward for i, tr in enumerate(trajectory[t:]))
             
             # get value_estimate
             value_estimate = self.value_estimator.predict(transition.state).detach()
             
             advantage = torch.FloatTensor([total_return]) - value_estimate
             advantage = torch.FloatTensor([advantage])
             
             # update value estimator
             self.value_estimator.update(transition.state, 
                                         torch.FloatTensor([total_return]), 
                                         self.value_optimizer)
             
             # update policy estimator
             action = torch.LongTensor([transition.action])
             self.policy_estimator.update(transition.state, 
                                          advantage, 
                                          action, 
                                          self.policy_optimizer)
             
     return stats
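
This training loop (REINFORCE with a value baseline) assumes policy and value estimators exposing predict and update methods. A minimal PyTorch sketch of that assumed interface, for states that are float feature vectors; the class names, layer sizes, and loss choices are assumptions, not the original implementation:

import torch
import torch.nn as nn

class PolicyEstimator(nn.Module):
    """Hypothetical softmax policy network matching the predict/update calls above."""
    def __init__(self, n_inputs, n_actions, hidden=128):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(n_inputs, hidden), nn.ReLU(),
                                 nn.Linear(hidden, n_actions), nn.Softmax(dim=-1))

    def predict(self, state):
        return self.net(torch.FloatTensor(state))

    def update(self, state, advantage, action, optimizer):
        probs = self.predict(state)
        # policy-gradient loss: -log pi(a|s) weighted by the advantage
        loss = (-torch.log(probs[action]) * advantage).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

class ValueEstimator(nn.Module):
    """Hypothetical state-value network used as the baseline."""
    def __init__(self, n_inputs, hidden=128):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(n_inputs, hidden), nn.ReLU(),
                                 nn.Linear(hidden, 1))

    def predict(self, state):
        return self.net(torch.FloatTensor(state)).squeeze(-1)

    def update(self, state, target, optimizer):
        # regress the baseline towards the observed return
        loss = nn.functional.mse_loss(self.predict(state), target.squeeze(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()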
Example No. 7
def q_sigma(env, num_episodes, n=10, gamma=0.9, alpha=0.1):
    """
    n-step Q(sigma) algorithm: Off-policy TD control.
    Args:
        env: OpenAI environment.
        num_episodes: Number of episodes to run for.
        n: future time steps to look ahead and calculate return for.
        gamma: Gamma discount factor.
        alpha: TD learning rate.
    
    Returns:
        A tuple (Q, stats).
        Q is the optimal action-value function, a dictionary mapping state -> action values.
        stats is an EpisodeStats object with two numpy arrays for episode_lengths and episode_rewards.
    """
    nA = env.action_space.n

    # Create dict Q. A mapping from state to action values
    Q = defaultdict(lambda: np.zeros(nA))

    # The behavior policy we follow to generate experience;
    # more exploratory (less greedy) than the target policy.
    behavior_policy = create_behavior_policy(Q, nA)

    # The target policy we are learning; less exploratory
    # (i.e. greedier) than the behavior policy.
    target_policy = create_target_policy(Q, nA)

    # Keeps track of useful statistics
    stats = plots.EpisodeStats(episode_lengths=np.zeros(num_episodes),
                               episode_rewards=np.zeros(num_episodes))

    for i_episode in range(num_episodes):
        if (i_episode + 1) % 100 == 0:
            print("\rEpisode {}/{}".format(i_episode + 1, num_episodes),
                  end="")
            sys.stdout.flush()

        T = sys.maxsize
        t = -1
        tau = 0

        stored_actions = {}
        stored_states = {}
        stored_rewards = {}
        stored_rho = {}
        stored_sigma = {}

        state = env.reset()
        action_probs = behavior_policy(state)
        action = np.random.choice(np.arange(nA), p=action_probs)
        sigma = get_sigma(nA, action)
        rho = target_policy(state)[action] / action_probs[action]

        # store selected params
        stored_states[0] = state
        stored_actions[0] = action
        stored_rho[0] = rho
        stored_sigma[0] = sigma

        while tau < (T - 1):
            t += 1
            if t < T:
                # take action and observe the environment's effects
                state, reward, done, _ = env.step(action)

                stored_states[(t + 1) % (n + 1)] = state
                stored_rewards[(t + 1) % (n + 1)] = reward

                stats.episode_lengths[i_episode] = t
                stats.episode_rewards[i_episode] += reward

                if done:
                    T = t + 1
                else:
                    action_probs = behavior_policy(state)
                    action = np.random.choice(np.arange(nA), p=action_probs)
                    sigma = get_sigma(nA, action)
                    rho = target_policy(state)[action] / action_probs[action]

                    stored_actions[(t + 1) % (n + 1)] = action
                    stored_sigma[(t + 1) % (n + 1)] = sigma
                    stored_rho[(t + 1) % (n + 1)] = rho

            # tau is the time whose estimate is being updated
            tau = t - n + 1
            if tau >= 0:
                if t + 1 < T:
                    G = Q[stored_states[(t + 1) % (n + 1)]][stored_actions[(t + 1) % (n + 1)]]

                for k in range(min(t + 1, T), tau, -1):
                    if k == T:
                        G = stored_rewards[T % (n + 1)]
                    else:
                        s_k = stored_states[k % (n + 1)]
                        a_k = stored_actions[k % (n + 1)]
                        r_k = stored_rewards[k % (n + 1)]
                        sigma_k = stored_sigma[k % (n + 1)]
                        rho_k = stored_rho[k % (n + 1)]
                        v_ = np.sum([target_policy(s_k)[a] * Q[s_k][a] for a in range(nA)])
                        G = (r_k
                             + gamma * (sigma_k * rho_k + (1 - sigma_k) * target_policy(s_k)[a_k]) * (G - Q[s_k][a_k])
                             + gamma * v_)

                s_tau, a_tau = stored_states[tau % (n + 1)], stored_actions[tau % (n + 1)]
                td_error = G - Q[s_tau][a_tau]
                Q[s_tau][a_tau] += alpha * td_error

    return Q, stats
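
get_sigma is not defined in the example. In Q(sigma), sigma in [0, 1] interpolates per step between a full sampled (Sarsa-style, importance-weighted) backup and a pure expectation (tree-backup) step. A hypothetical stand-in that ignores its arguments and simply flips a coin:

import numpy as np

def get_sigma(nA, action, p=0.5):
    """Hypothetical sigma schedule: 1 = sampled (Sarsa-like) step, 0 = expectation (tree-backup) step."""
    return np.random.binomial(1, p)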
Example No. 8
def n_step_tree_backup(env, num_episodes, n=10, gamma=0.9, alpha=0.1):
    """
    n-step Tree Backup algorithm: Off-policy TD control.
    Args:
        env: OpenAI environment.
        num_episodes: Number of episodes to run for.
        n: future time steps to look ahead and calculate return for.
        gamma: Gamma discount factor.
        alpha: TD learning rate.
    
    Returns:
        A tuple (Q, stats).
        Q is the optimal action-value function, a dictionary mapping state -> action values.
        stats is an EpisodeStats object with two numpy arrays for episode_lengths and episode_rewards.
    """
    nA = env.action_space.n

    # Create dict Q. A mapping from state to action values
    Q = defaultdict(lambda: np.zeros(nA))

    # The behavior policy we follow to generate experience;
    # more exploratory (less greedy) than the target policy.
    behavior_policy = create_behavior_policy(Q, nA)

    # The target policy we are learning; less exploratory
    # (i.e. greedier) than the behavior policy.
    target_policy = create_target_policy(Q, nA)

    # Keeps track of useful statistics
    stats = plots.EpisodeStats(episode_lengths=np.zeros(num_episodes),
                               episode_rewards=np.zeros(num_episodes))

    for i_episode in range(num_episodes):
        if (i_episode + 1) % 100 == 0:
            print("\rEpisode {}/{}.".format(i_episode + 1, num_episodes),
                  end="")
            sys.stdout.flush()

        T = sys.maxsize
        tau = 0
        t = -1

        stored_actions = {}
        stored_rewards = {}
        stored_states = {}

        # reset env to get initial state
        state = env.reset()

        # get action probs from behavior policy
        action_probs = behavior_policy(state)
        action = np.random.choice(np.arange(nA), p=action_probs)

        stored_actions[0] = action
        stored_states[0] = state

        while tau < (T - 1):
            t += 1
            if t < T:
                # take action and observe effects of the environment
                state, reward, done, _ = env.step(action)

                stored_states[(t + 1) % (n + 1)] = state
                stored_rewards[(t + 1) % (n + 1)] = reward

                stats.episode_lengths[i_episode] = t
                stats.episode_rewards[i_episode] += reward

                if done:
                    T = t + 1
                else:
                    action_probs = behavior_policy(state)
                    action = np.random.choice(np.arange(nA), p=action_probs)
                    stored_actions[(t + 1) % (n + 1)] = action

            tau = t - n + 1
            if tau >= 0:
                if (t + 1) >= T:
                    G = stored_rewards[T % (n + 1)]
                else:
                    # get s[t+1] from stored states
                    s_t1 = stored_states[(t + 1) % (n + 1)]
                    # calculate the sum over the leaf actions
                    leaf_sum = np.sum([target_policy(s_t1)[a] * Q[s_t1][a]
                                       for a in range(nA)])
                    G = stored_rewards[(t + 1) % (n + 1)] + gamma * leaf_sum

                for k in range(min(t, T - 1), tau, -1):
                    # get kth action and state
                    s_k, a_k = stored_states[k % (n + 1)], stored_actions[k % (n + 1)]
                    a_probs = np.sum([
                        target_policy(s_k)[a] * Q[s_k][a] for a in range(nA)
                        if a != a_k
                    ])
                    G = stored_rewards[k % (n + 1)] + gamma * (
                        a_probs + target_policy(s_k)[a_k] * G)

                s_tau, a_tau = stored_states[tau % (n + 1)], stored_actions[tau % (n + 1)]
                td_error = G - Q[s_tau][a_tau]
                Q[s_tau][a_tau] += alpha * td_error

    return Q, stats
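
The backwards loop applies the tree-backup recursion G <- r_k + gamma * (sum over a != a_k of pi(a|s_k) * Q(s_k, a) + pi(a_k|s_k) * G). A toy standalone check of one such step with made-up numbers:

import numpy as np

q_sk = np.array([1.0, 2.0])   # made-up Q[s_k]
pi = np.array([0.25, 0.75])   # made-up target-policy probabilities at s_k
r_k, gamma, a_k, G = 0.5, 0.9, 1, 4.0
off_branch = sum(pi[a] * q_sk[a] for a in range(2) if a != a_k)
G = r_k + gamma * (off_branch + pi[a_k] * G)
print(G)  # 0.5 + 0.9 * (0.25 * 1.0 + 0.75 * 4.0) = 3.425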
Example No. 9
def n_step_expected_sarsa(env,
                          num_episodes,
                          n=5,
                          gamma=0.9,
                          epsilon=0.1,
                          alpha=0.1):
    """
    n-step Expected SARSA: On-policy TD control. Finds the optimal epsilon-greedy policy. The
    algorithm is the same as n-step SARSA, except that the last element of the return is the
    expectation over all action possibilities, weighted by their probabilities under pi (the
    policy we are following).
    
    Args:
        env: The OpenAI environment.
        num_episodes: Number of episodes to run for.
        n: future time steps to look ahead and calculate return for.
        gamma: Gamma discount factor.
        alpha: TD learning rate.
        epsilon: Chance to sample a random action. Float between 0 and 1.
        
    Returns:
        A tuple (Q, stats).
        Q: The optimal action-value function, a dictionary mapping
           state -> action values.
        stats: An EpisodeStats object with two numpy arrays for episode_lengths 
               and episode_rewards.
    """

    nA = env.action_space.n

    # The final action-value function.
    # A nested dictionary that maps state -> (action -> action-value).
    Q = defaultdict(lambda: np.zeros(nA))

    # policy we are following
    policy = create_epsilon_greedy_policy(Q, nA, epsilon)

    # track useful stats to plot
    stats = plots.EpisodeStats(
        episode_lengths=np.zeros(num_episodes),
        episode_rewards=np.zeros(num_episodes),
    )

    for i_episode in range(num_episodes):
        # print the current episode for debugging
        if (i_episode + 1) % 100 == 0:
            print("\rEpisode {}/{}.".format(i_episode + 1, num_episodes),
                  end="")
            sys.stdout.flush()

        # initializations
        stored_rewards = {}
        stored_states = {}
        stored_actions = {}

        T = sys.maxsize
        t = -1
        tau = 0

        # reset OpenAI env to get the initial state
        state = env.reset()

        # get action probabilities from the policy function
        action_probs = policy(state)

        # sample action according to the action probabilities
        action = np.random.choice(np.arange(nA), p=action_probs)

        # store current action and state
        stored_actions[0] = action
        stored_states[0] = state

        while tau < (T - 1):
            t += 1
            if t < T:
                # observe the environment's effects after taking the sampled action
                next_state, reward, done, _ = env.step(action)

                # assign next_state to current state
                state = next_state

                stored_states[(t + 1) % (n + 1)] = state
                stored_rewards[(t + 1) % (n + 1)] = reward

                stats.episode_lengths[i_episode] = t
                stats.episode_rewards[i_episode] += reward

                if done:
                    T = t + 1
                else:
                    # select and store action A[t+1]
                    action_probs = policy(state)
                    action = np.random.choice(np.arange(nA), p=action_probs)
                    stored_actions[(t + 1) % (n + 1)] = action

            tau = t - n + 1
            if tau >= 0:
                # calculate the return
                G = np.sum([
                    (gamma**(i - tau - 1)) * stored_rewards[i % (n + 1)]
                    for i in range(tau + 1,
                                   min(tau + n, T) + 1)
                ])

                # here we calculate the value of all action possibilities, weighted
                # by their probabilities under pi (the policy we are following).
                if tau + n < T:
                    exp_sarsa_update = np.sum([
                        policy(stored_states[(tau + n) % (n + 1)])[a] *
                        Q[stored_states[(tau + n) % (n + 1)]][a]
                        for a in range(nA)
                    ])
                    G += (gamma**n) * exp_sarsa_update

                # update Q value here
                s_tau, a_tau = stored_states[tau % (n + 1)], stored_actions[tau % (n + 1)]
                Q[s_tau][a_tau] += alpha * (G - Q[s_tau][a_tau])

    return Q, stats
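
The exp_sarsa_update term is the only difference from n-step SARSA: instead of bootstrapping from the single Q value of the sampled action at S_{tau+n}, it bootstraps from the Q values weighted by the epsilon-greedy probabilities of the policy being followed. A toy standalone check of that expectation with made-up numbers:

import numpy as np

q_values = np.array([1.0, 3.0, 2.0])   # made-up Q[S_{tau+n}]
epsilon, nA = 0.1, 3
probs = np.ones(nA) * epsilon / nA
probs[np.argmax(q_values)] += 1.0 - epsilon
print(np.dot(probs, q_values))         # 0.9 * 3.0 + (0.1 / 3) * (1.0 + 3.0 + 2.0) = 2.9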