def modified_policy_iteration(mdp: MDP,
                              gamma: float,
                              epsilon: float,
                              k: int = 5) -> Tuple[Dict, Dict]:
    """Solve ``mdp`` with modified policy iteration.

    Every outer round performs ``k`` in-place evaluation sweeps under the
    current policy, followed by one greedy improvement sweep that also
    overwrites each state's value with its best one-step backup.  The loop
    stops once an improvement sweep changes no state value by more than
    ``epsilon``.

    Args:
        mdp: Object exposing ``S`` (iterable of states), ``A`` (iterable of
            actions), ``R(s, a)`` (added as the immediate term of a backup)
            and ``P(s_prime, s, a)`` (weight applied to the value of
            ``s_prime``).  ``S`` and ``A`` may be lists, tuples or sets.
        gamma: Discount factor multiplying the weighted successor values.
        epsilon: Stopping threshold on the largest value change seen during
            an improvement sweep.
        k: Number of policy-evaluation sweeps per outer round.

    Returns:
        ``(pi, V)`` where ``pi`` maps each state to its greedy action and
        ``V`` maps each state to its current value estimate.

    Raises:
        ValueError: If ``mdp.A`` contains no actions.
    """
    # Snapshot both collections once: random.sample() rejects sets on
    # Python 3.11+, and a list gives a fixed iteration order for the sweeps.
    states = list(mdp.S)
    actions = list(mdp.A)
    if not actions:
        raise ValueError("mdp.A must contain at least one action")

    # Start from one arbitrary action shared by every state.
    initial_action = random.choice(actions)
    pi = {s: initial_action for s in states}
    V = {s: 0. for s in states}

    def backup(s, a):
        # One-step lookahead value of taking ``a`` in ``s`` under current V.
        return mdp.R(s, a) + gamma * sum(
            mdp.P(s_prime, s, a) * V[s_prime] for s_prime in states)

    while True:
        # Partial policy evaluation: k in-place sweeps under the fixed pi.
        for _ in range(k):
            for s in states:
                V[s] = backup(s, pi[s])

        # Greedy improvement; ties go to the first action in iteration order.
        delta = 0.
        for s in states:
            old_value = V[s]
            q = {a: backup(s, a) for a in actions}
            pi[s] = max(q, key=q.get)
            V[s] = q[pi[s]]
            delta = max(delta, abs(V[s] - old_value))
        if delta <= epsilon:
            break
    return pi, V
def async_value_iteration(mdp: MDP,
                          gamma: float,
                          num_iterations: int = 1000) -> Tuple[Dict, Dict]:
    """Estimate action values by randomly ordered single-entry backups.

    Each iteration draws one state and one action uniformly at random and
    replaces that single ``Q`` entry with its one-step backup, using the
    current maximum over actions as the value of every successor state.
    After the updates, a greedy policy is read off ``Q``.

    Args:
        mdp: Object exposing ``S`` (iterable of states), ``A`` (iterable of
            actions), ``R(s, a)`` (added as the immediate term of a backup)
            and ``P(s_prime, s, a)`` (weight applied to the value of
            ``s_prime``).  ``S`` and ``A`` may be lists, tuples or sets.
        gamma: Discount factor multiplying the weighted successor values.
        num_iterations: Number of single-entry updates to perform.

    Returns:
        ``(pi, Q)`` where ``pi`` maps each state to the action with the
        largest ``Q`` value (first one in iteration order on ties) and ``Q``
        maps ``(state, action)`` pairs to their estimates.

    Raises:
        ValueError: If updates are requested while ``mdp.S`` or ``mdp.A`` is
            empty, or if there are states but no actions to choose from.
    """
    # Snapshot both collections once: random.sample() rejects sets on
    # Python 3.11+, and random.choice needs an indexable sequence.
    states = list(mdp.S)
    actions = list(mdp.A)
    if num_iterations > 0 and not (states and actions):
        raise ValueError(
            "cannot sample updates from an empty state or action set")

    Q = {(s, a): 0. for a in actions for s in states}
    for _ in range(num_iterations):
        s = random.choice(states)
        a = random.choice(actions)
        Q[(s, a)] = mdp.R(s, a) + gamma * sum(
            mdp.P(s_prime, s, a) *
            max(Q[(s_prime, a_prime)] for a_prime in actions)
            for s_prime in states)

    # Greedy policy extraction from the final table.
    pi = {}
    for s in states:
        values = {a: Q[(s, a)] for a in actions}
        pi[s] = max(values, key=values.get)
    return pi, Q
def value_iteration(mdp: MDP, gamma: float,
                    epsilon: float) -> Tuple[Dict, Dict]:
    """Run in-place value iteration on ``mdp`` until the values settle.

    Every sweep visits the states in ``mdp.S`` order and replaces each
    state's entry with its best one-step backup over ``mdp.A``, immediately
    reusing values already updated earlier in the same sweep.  Sweeping
    stops after the first sweep in which no state value moved by more than
    ``epsilon``.

    Args:
        mdp: Object exposing ``S``, ``A``, ``R(s, a)`` (added as the
            immediate term of a backup) and ``P(s_prime, s, a)`` (weight
            applied to the value of ``s_prime``).
        gamma: Discount factor multiplying the weighted successor values.
        epsilon: Stopping threshold on the largest per-state value change
            within one sweep.

    Returns:
        ``(pi, V)`` where ``pi`` maps each state to its greedy action (the
        first maximiser in ``mdp.A`` order) and ``V`` maps each state to
        its value estimate.
    """
    # During the sweeps each entry is an (action, value) pair; the action is
    # None until the state has been backed up once.
    table = {state: (None, 0.) for state in mdp.S}

    converged = False
    while not converged:
        largest_change = 0.
        for state in mdp.S:
            previous = table[state][1]
            backups = {}
            for action in mdp.A:
                reward = mdp.R(state, action)
                expected_next = sum(
                    mdp.P(nxt, state, action) * table[nxt][1]
                    for nxt in mdp.S)
                backups[action] = reward + gamma * expected_next
            table[state] = (max(backups, key=backups.get),
                            max(backups.values()))
            largest_change = max(largest_change,
                                 abs(table[state][1] - previous))
        converged = largest_change <= epsilon

    # Split the pairs: actions go to pi, and the table itself is rewritten
    # in place to hold the bare values before being returned.
    pi = {}
    for state in mdp.S:
        action, value = table[state]
        pi[state] = action
        table[state] = value
    return pi, table