def sarsa(lamb: float, num_episodes: int, Qstar, record=False, gamma: float = 1.0):
    """Learn action values with tabular SARSA(lambda) and accumulating traces.

    Each episode starts from ``State(deal=True)``. Actions are chosen with
    ``get_e_greedy_action`` using the per-state visit counts ``N_s``. After
    every step, every (dealer, player, action) entry is updated with step
    size ``1 / N(s, a)``, scaled by its eligibility trace.

    Args:
        lamb: Trace-decay parameter lambda. This is a fraction, typically in
            [0, 1], so it is annotated ``float`` rather than ``int``.
        num_episodes: Number of episodes to run.
        Qstar: Reference action-value table passed to ``calc_mse`` when
            ``record`` is true.
        record: If true, append ``calc_mse(Q, Qstar)`` after every episode.
        gamma: Discount factor. The default of 1.0 (undiscounted) reproduces
            the original update exactly.

    Returns:
        ``(Q, mses)``: the learned action-value table and the list of
        per-episode MSEs. The list is empty unless ``record`` is true.
    """
    Q = state_action_map(plus=True)
    N = state_action_map()  # N(s, a) visit counts; used for the step size
    N_s = state_map(plus=True)  # N(s) visit counts; used for epsilon
    mses = []
    # Traces decay by gamma * lambda after every update.
    trace_decay = gamma * lamb
    for _ in range(num_episodes):
        E = state_action_map()  # eligibility traces, reset each episode
        s = State(deal=True)
        a = get_e_greedy_action(Q, N_s, s)
        while not s.terminal():
            # Capture the key before step(), so the update below targets the
            # state the action was taken in even if step() were to modify s.
            key = s.get_state()
            N_s[key] += 1
            N[key, a] += 1
            s_dash, r = step(s, a)
            a_dash = get_e_greedy_action(Q, N_s, s_dash)
            # NOTE(review): when s_dash is terminal, this relies on Q for
            # terminal states keeping its initial value (presumably 0, via
            # plus=True). The sweep below only touches
            # DEALER_RANGE x PLAYER_RANGE. Confirm against state_action_map.
            delta = r + gamma * Q[s_dash.get_state(), a_dash] - Q[key, a]
            E[key, a] += 1
            for d in DEALER_RANGE:
                for p in PLAYER_RANGE:
                    for action in ACTIONS:
                        sa = ((d, p), action)
                        # 1e-9 guards against division by zero for pairs
                        # never visited (whose trace is zero anyway).
                        Q[sa] += (1 / (N[sa] + 1e-9)) * delta * E[sa]
                        E[sa] *= trace_decay
            s, a = s_dash, a_dash
        if record:
            mses.append(calc_mse(Q, Qstar))
    return Q, mses
def sample_episode(pi):
    """Roll out one episode that follows the deterministic policy ``pi``.

    Args:
        pi: Mapping from ``State.get_state()`` keys to the action to take.

    Returns:
        ``(history, r)``. ``history`` is the list of ``[state, action]``
        pairs visited, in order. ``r`` is the reward returned by the final
        ``step`` call, or 0 if the dealt state is already terminal and no
        step is taken.
    """
    history = []
    s = State(deal=True)
    # Bind the reward up front. Otherwise, if the dealt state were already
    # terminal, the loop would never run and returning ``r`` would raise
    # UnboundLocalError.
    r = 0
    while not s.terminal():
        key = s.get_state()
        a = pi[key]
        # Rewards are not appended to the history: per the original author,
        # a reward is only received on entering the terminal state, so the
        # last ``r`` is the episode's return.
        history.append([key, a])
        s, r = step(s, a)
    return history, r
def get_e_greedy_action(Q: dict, N: dict, state: State, n0: float = 100):
    """Pick an action epsilon-greedily, with epsilon decaying by visit count.

    ``epsilon = n0 / (n0 + N[state])``. When a uniform draw exceeds epsilon
    (probability ``1 - epsilon``), the action with the highest
    ``Q[state, action]`` is returned; ties go to the earliest entry in
    ``ACTIONS``. Otherwise an action is chosen uniformly at random.

    Args:
        Q: Action-value table indexed as ``Q[state_key, action]``.
        N: Visit counts indexed by ``state.get_state()``.
        state: The state to choose an action for.
        n0: Exploration constant N0. The default of 100 matches the
            previously hard-coded value.

    Returns:
        An element of ``ACTIONS``. Unlike the old ``-1e9`` sentinel loop,
        this never returns ``None``, even for very negative Q values.
    """
    key = state.get_state()
    epsilon = n0 / (n0 + N[key])
    if np.random.uniform() > epsilon:
        # max() keeps the first maximal element, matching the original
        # strict '>' scan's tie-breaking.
        return max(ACTIONS, key=lambda action: Q[key, action])
    return random.choice(ACTIONS)