def get_new_state_value(mdp, state_values, state, gamma):
    """ Computes next V(s) as in formula above. Please do not change state_values in process. """
    if mdp.is_terminal(state):
        return 0
    # Bellman optimality backup: take the best action value over all available actions.
    # Start from -inf (not 0) so states whose action values are all negative are handled correctly,
    # and avoid shadowing the built-in max.
    best_value = float('-inf')
    for action in mdp.get_possible_actions(state):
        action_value = get_action_value(mdp, state_values, state, action, gamma)
        if action_value > best_value:
            best_value = action_value
    return best_value
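# For reference, a restatement (my paraphrase, not the notebook's own text) of the backup
# get_new_state_value performs, assuming get_action_value computes Q(s, a):
#     V_{i+1}(s) = max_a Q_i(s, a) = max_a sum_{s'} P(s'|s, a) * (r(s, a, s') + gamma * V_i(s'))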
def get_optimal_action_for_plot(mdp, state_values, state, gamma=0.9):
    """ Finds optimal action using formula above. """
    if mdp.is_terminal(state):
        return None
    next_actions = mdp.get_possible_actions(state)
    try:
        from mdp_get_action_value import get_action_value
    except ImportError:
        raise ImportError("Implement get_action_value(mdp, state_values, state, action, gamma) "
                          "in the file \"mdp_get_action_value.py\".")
    q_values = [get_action_value(mdp, state_values, state, action, gamma) for action in next_actions]
    optimal_action = next_actions[np.argmax(q_values)]
    return optimal_action
def get_optimal_action(mdp, state_values, state, gamma=0.9):
    """ Finds optimal action using formula above. """
    if mdp.is_terminal(state):
        return None
    # Greedy action: argmax over action values. Initialize with -inf so a valid action
    # is returned even when every Q(s, a) is non-positive.
    best_value, best_action = float('-inf'), None
    for action in mdp.get_possible_actions(state):
        action_value = get_action_value(mdp, state_values, state, action, gamma)
        if action_value > best_value:
            best_value, best_action = action_value, action
    return best_action
def compute_vpi(mdp, policy, gamma, num_iter=1000, min_difference=1e-5):
    """
    Computes V^pi(s) FOR ALL STATES under given policy.
    :param policy: a dict of currently chosen actions {s : a}
    :returns: a dict {state : V^pi(state) for all states}
    """
    # Iterative policy evaluation: repeatedly back up V^pi(s) = Q^pi(s, policy[s])
    # until the values stop changing (a single sweep from all-zero values is not enough).
    state_values = {s: 0 for s in mdp.get_all_states()}
    for _ in range(num_iter):
        new_values = {
            s: 0 if mdp.is_terminal(s)
            else get_action_value(mdp, state_values, s, policy[s], gamma)
            for s in mdp.get_all_states()
        }
        diff = max(abs(new_values[s] - state_values[s]) for s in mdp.get_all_states())
        state_values = new_values
        if diff < min_difference:
            break
    return state_values
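# A minimal policy-iteration sketch showing how compute_vpi might be combined with
# get_optimal_action. This loop, its name, and its stopping rule are my assumption,
# not a cell from the original notebook; `mdp` and `gamma` are assumed to be defined as below.
def run_policy_iteration(mdp, gamma, num_iter=100):
    # start from an arbitrary deterministic policy: first available action in each state
    policy = {s: mdp.get_possible_actions(s)[0]
              for s in mdp.get_all_states() if not mdp.is_terminal(s)}
    for _ in range(num_iter):
        state_values = compute_vpi(mdp, policy, gamma)                   # policy evaluation
        new_policy = {s: get_optimal_action(mdp, state_values, s, gamma)
                      for s in policy}                                   # greedy improvement
        if new_policy == policy:  # policy stable: it is greedy w.r.t. its own values
            break
        policy = new_policy
    return state_values, policy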
import numpy as np

from IPython.display import clear_output, display
from time import sleep

from mdp import MDP, has_graphviz

print("Graphviz available:", has_graphviz)

# transition_probs and rewards are assumed to be defined in the cells above
mdp = MDP(transition_probs, rewards, initial_state='s0')

if has_graphviz:
    from mdp import plot_graph, plot_graph_with_state_values, \
        plot_graph_optimal_strategy_and_state_values
    display(plot_graph(mdp))

test_Vs = {s: i for i, s in enumerate(sorted(mdp.get_all_states()))}
assert np.isclose(get_action_value(mdp, test_Vs, 's2', 'a1', 0.9), 0.69)
assert np.isclose(get_action_value(mdp, test_Vs, 's1', 'a0', 0.9), 3.95)

test_Vs_copy = dict(test_Vs)
assert np.isclose(get_new_state_value(mdp, test_Vs, 's0', 0.9), 1.8)
assert np.isclose(get_new_state_value(mdp, test_Vs, 's2', 0.9), 1.08)
assert test_Vs == test_Vs_copy, "please do not change state_values in get_new_state_value"

# parameters
gamma = 0.9       # discount for MDP
num_iter = 100    # maximum iterations, excluding initialization
# stop VI if new values are this close to old values (or closer)
min_difference = 0.001

# initialize V(s)
state_values = {s: 0 for s in mdp.get_all_states()}
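# A minimal value-iteration sketch using the parameters above. The loop body is my
# assumption about how get_new_state_value is meant to be applied; the original
# notebook's logging/plotting cell may differ.
for i in range(num_iter):
    # full Bellman optimality backup over all states
    new_state_values = {s: get_new_state_value(mdp, state_values, s, gamma)
                        for s in mdp.get_all_states()}
    diff = max(abs(new_state_values[s] - state_values[s]) for s in mdp.get_all_states())
    state_values = new_state_values
    print("iter %4i   |   diff: %6.5f   |   V(s0) = %.3f" % (i, diff, state_values['s0']))
    if diff < min_difference:
        break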