def loss_fn(model_params):  # policy symmetry
    P, r = ss.parse_model_params(n_states, n_actions, model_params)
    return np.sum(np.square(
        V(utils.softmax(P), r, pis) - V(utils.softmax(P), r, np.flip(pis, 1))))
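# A small illustration (not from the original code) of what np.flip(pis, 1) does
# in the symmetry loss above: axis 1 is the state axis, so flipping reverses the
# order of states in each policy; with two states this swaps their action
# distributions, and the loss penalises value functions that change under that swap.
import numpy as np

pi = np.array([[0.9, 0.1],   # state 0: mostly action 0
               [0.2, 0.8]])  # state 1: mostly action 1
pis = np.stack([pi])         # batch of one policy, shape [1, S, A]
print(np.flip(pis, 1))       # [[[0.2, 0.8], [0.9, 0.1]]] -- states swapped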
def generate_model_iteration():
    n_states, n_actions = 2, 2

    mdp = utils.build_random_mdp(n_states, n_actions, 0.5)
    pis = utils.gen_grid_policies(7)

    # needs its own init. alternatively could find an init that matches the value of other inits?
    init = rnd.standard_normal((mdp.S * mdp.S * mdp.A + mdp.S * mdp.A))

    vs = utils.polytope(mdp.P, mdp.r, mdp.discount, pis)

    plt.figure(figsize=(16, 16))
    plt.scatter(vs[:, 0], vs[:, 1], c='b', s=10, alpha=0.75)

    lr = 0.01
    pi_star = utils.solve(policy_iteration(mdp),
                          utils.softmax(rnd.standard_normal((mdp.S, mdp.A))))[-1]

    # adversarial pis
    apis = utils.get_deterministic_policies(mdp.S, mdp.A)
    apis = np.stack(apis)

    update_fn = model_iteration(mdp, lr, apis)
    params = utils.solve(update_fn, init)
    params = [parse_model_params(mdp.S, mdp.A, p) for p in params]

    vs = np.vstack([
        utils.value_functional(utils.softmax(p_logits), r, pi_star, mdp.discount).T
        for p_logits, r in params
    ])
    n = vs.shape[0]
    plt.scatter(vs[0, 0], vs[0, 1], c='g', label='PG')
    plt.scatter(vs[1:-1, 0], vs[1:-1, 1], c=range(n - 2), cmap='spring', s=10)
    plt.scatter(vs[-1, 0], vs[-1, 1], c='g', marker='x')

    p_logits, r = params[-1]
    vs = utils.polytope(utils.softmax(p_logits), r, mdp.discount, pis)
    plt.scatter(vs[:, 0], vs[:, 1], c='r', s=10, alpha=0.75)

    plt.title('Model iteration')
    plt.xlabel('Value of state 1')
    plt.ylabel('Value of state 2')
    # plt.show()
    plt.savefig('figs/model_iteration_1.png')

    learned_mdp = utils.MDP(mdp.S, mdp.A, utils.softmax(p_logits), r, mdp.discount, mdp.d0)
    pi_star_approx = utils.solve(
        policy_iteration(learned_mdp),
        utils.softmax(rnd.standard_normal((mdp.S, mdp.A))))[-1]
    print(pi_star_approx, '\n', pi_star)
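# `parse_model_params` comes from mdp.search_spaces; a hedged sketch of what it
# is assumed to do, inferred from the init shape S*S*A + S*A used above:
def parse_model_params_sketch(n_states, n_actions, params):
    n_p = n_states * n_states * n_actions
    p_logits = params[:n_p].reshape((n_states, n_states, n_actions))  # transition logits [s', s, a]
    r = params[n_p:].reshape((n_states, n_actions))                   # rewards [s, a]
    return p_logits, r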
def parameterised_policy_gradient_iteration(mdp, lr):
    dlogpi_dw = jacrev(lambda cores: np.log(utils.softmax(build(cores), axis=1) + 1e-8))
    dHdw = jacrev(lambda cores: utils.entropy(utils.softmax(build(cores))))

    @jit
    def update_fn(cores):
        V = utils.value_functional(mdp.P, mdp.r, utils.softmax(build(cores), axis=1), mdp.discount)
        Q = utils.bellman_operator(mdp.P, mdp.r, V, mdp.discount)
        A = Q - V
        grads = [np.einsum('ijkl,ij->kl', d, A) for d in dlogpi_dw(cores)]
        return [c + lr * utils.clip_by_norm(g, 100) + 1e-6 * dH
                for c, g, dH in zip(cores, grads, dHdw(cores))]
    return update_fn
def lmdp_decoder(u, P, lr=10):
    """
    Given the optimal control dynamics, optimise a softmax parameterisation
    of the policy that yields those same dynamics.
    """
    # NOTE is there a way to solve this using linear equations?
    # W = log(P_pi)
    #   = sum_a log(P[a]) + log(pi[a])
    # M = log(u)
    # UW - UM = 0
    # U(W - M) = 0,  W = M = sum_a log(P[a]) + log(pi[a])
    # 0 = sum_a log(P[a]) + log(pi[a]) - M

    def loss(pi_logits):
        pi = utils.softmax(pi_logits)
        # P_pi(s'|s) = \sum_a pi(a|s) P(s'|s, a)
        P_pi = np.einsum('ijk,jk->ij', P, pi)
        return np.sum(np.multiply(u, np.log(u / P_pi)))  # KL(u || P_pi)

    dLdw = jit(grad(loss))

    def update_fn(w):
        return w - lr * dLdw(w)

    init = rnd.standard_normal((P.shape[0], P.shape[-1]))
    pi_star_logits = utils.solve(update_fn, init)[-1]
    return utils.softmax(pi_star_logits)
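# Quick check (a sketch) that the einsum in the loss computes
# P_pi(s'|s) = \sum_a pi(a|s) P(s'|s, a) with P indexed [s', s, a]:
import numpy
import numpy.random as rnd

S, A = 3, 2
P_test = rnd.random((S, S, A)); P_test = P_test / P_test.sum(axis=0)
pi_test = rnd.random((S, A));   pi_test = pi_test / pi_test.sum(axis=1, keepdims=True)

P_pi = numpy.einsum('ijk,jk->ij', P_test, pi_test)
explicit = numpy.zeros((S, S))
for i in range(S):        # s'
    for j in range(S):    # s
        explicit[i, j] = sum(pi_test[j, a] * P_test[i, j, a] for a in range(A))
print(numpy.allclose(P_pi, explicit), P_pi.sum(axis=0))  # True, columns sum to 1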
def lmdp_option_decoder(u, P, lr=1, k=5):
    """
    Given the optimal control dynamics, optimise a softmax parameterisation
    of a policy over options that yields those same dynamics.
    """
    n_states = P.shape[0]
    n_actions = P.shape[-1]

    # the augmented transition fn. [n_states, n_states, n_options]
    P_options = option_transition_fn(P, k)

    def loss(option_logits):
        options = utils.softmax(option_logits)
        # P_pi(s'|s) = \sum_w pi(w|s) P(s'|s, w)
        P_pi = np.einsum('ijk,jk->ij', P_options, options)
        return np.sum(np.multiply(u, np.log(u / P_pi)))  # KL(u || P_pi)

    dLdw = jit(grad(loss))

    def update_fn(w):
        return w - lr * dLdw(w)

    n_options = sum([n_actions**(i + 1) for i in range(k)])
    print('N options: {}'.format(n_options))
    init = rnd.standard_normal((P.shape[0], n_options))
    pi_star_logits = utils.solve(update_fn, init)[-1]
    return utils.softmax(pi_star_logits)
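# `option_transition_fn` is defined elsewhere; a hedged sketch of what it is
# assumed to compute here: for every action sequence of length 1..k, compose the
# per-action transition matrices (P indexed [s', s, a]), giving a tensor of shape
# [n_states, n_states, n_options] with n_options = sum_i n_actions**i, matching
# the count printed above.
import itertools

def option_transition_fn_sketch(P, k):
    n_states, _, n_actions = P.shape
    columns = []
    for length in range(1, k + 1):
        for seq in itertools.product(range(n_actions), repeat=length):
            M = np.eye(n_states)
            for a in seq:
                M = np.dot(P[:, :, a], M)  # apply the actions in sequence order
            columns.append(M)
    return np.stack(columns, axis=-1)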
def graph_PG():
    # ffmpeg -framerate 10 -start_number 0 -i %d.png -c:v libx264 -r 30 -pix_fmt yuv420p out.mp4
    n_states = 6
    n_actions = 4

    det_pis = utils.get_deterministic_policies(n_states, n_actions)
    print('n pis: {}'.format(len(det_pis)))
    mdp = utils.build_random_mdp(n_states, n_actions, 0.9)

    A = graph.mdp_topology(det_pis)
    G = nx.from_numpy_array(A)
    pos = nx.spring_layout(G, iterations=200)
    basis = graph.construct_mdp_basis(det_pis, mdp)

    init_logits = np.random.standard_normal((n_states, n_actions))
    init_v = utils.value_functional(mdp.P, mdp.r, utils.softmax(init_logits), mdp.discount).squeeze()
    a = graph.sparse_coeffs(basis, init_v, lr=0.1)

    print('\nSolving PG')
    pis = utils.solve(search_spaces.policy_gradient_iteration_logits(mdp, 0.1), init_logits)
    print("\n{} policies to vis".format(len(pis)))
    n = len(pis)
    # pis = pis[::n//100]
    pis = pis[0:20]

    for i, pi in enumerate(pis[:-1]):
        print('Iteration: {}'.format(i))
        v = utils.value_functional(mdp.P, mdp.r, pi, mdp.discount).squeeze()
        a = graph.sparse_coeffs(basis, v, lr=0.1, a_init=a)
        plt.figure()
        nx.draw(G, pos, node_color=a)
        # plt.show()
        plt.savefig('figs/pg_graphs/{}.png'.format(i))
        plt.close()
def graph_PI():
    n_states = 10
    n_actions = 2

    det_pis = utils.get_deterministic_policies(n_states, n_actions)
    print('n pis: {}'.format(len(det_pis)))
    mdp = utils.build_random_sparse_mdp(n_states, n_actions, 0.5)

    A = graph.mdp_topology(det_pis)
    G = nx.from_numpy_array(A)
    pos = nx.spring_layout(G, iterations=200)
    basis = graph.construct_mdp_basis(det_pis, mdp)

    init_pi = utils.softmax(np.random.standard_normal((n_states, n_actions)))
    init_v = utils.value_functional(mdp.P, mdp.r, init_pi, mdp.discount).squeeze()
    a = graph.sparse_coeffs(basis, init_v, lr=0.1)

    pis = utils.solve(search_spaces.policy_iteration(mdp), init_pi)
    print("\n{} policies to vis".format(len(pis)))

    for i, pi in enumerate(pis[:-1]):
        print('Iteration: {}'.format(i))
        v = utils.value_functional(mdp.P, mdp.r, pi, mdp.discount).squeeze()
        a = graph.sparse_coeffs(basis, v, lr=0.1, a_init=a)
        plt.figure(figsize=(16, 16))
        nx.draw(G, pos, node_color=a, node_size=150)
        # plt.show()
        plt.savefig('figs/pi_graphs/{}.png'.format(i))
        plt.close()
def generate_pg(mdp, c, lr=0.01):
    init_pi = utils.random_policy(mdp.S, mdp.A)
    init_logit = np.log(init_pi)
    logits = utils.solve(ss.policy_gradient_iteration_logits(mdp, lr), init_logit)
    vs = np.stack([utils.value_functional(mdp.P, mdp.r, utils.softmax(logit), mdp.discount)
                   for logit in logits])[:, :, 0]
    n = vs.shape[0]
    plt.scatter(vs[0, 0], vs[0, 1], c=c, s=30, label='{}'.format(n))
    plt.scatter(vs[1:-1, 0], vs[1:-1, 1], c=range(n - 2), cmap='viridis', s=10)
    plt.scatter(vs[-1, 0], vs[-1, 1], c='m', marker='x')
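# Hypothetical usage (a sketch, assuming the same utils/plt imports as above):
# overlay a few PG runs, in different colours, on the value polytope.
mdp = utils.build_random_mdp(2, 2, 0.5)
vs = utils.polytope(mdp.P, mdp.r, mdp.discount, utils.gen_grid_policies(31))
plt.figure(figsize=(8, 8))
plt.scatter(vs[:, 0], vs[:, 1], c='b', s=5, alpha=0.5)
for c in ['g', 'k', 'y']:
    generate_pg(mdp, c)
plt.legend()
plt.savefig('figs/pg_runs.png')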
def policy_gradient_iteration_logits(mdp, lr):
    # NOTE this doesn't seem to behave nicely in larger state spaces!?
    # d/dlogits E_pi[V] = E[V . d/dlogits log pi]  (score-function identity)
    # dlogpi_dlogit = jacrev(lambda logits: np.log(utils.softmax(logits)+1e-8))
    dHdlogit = grad(lambda logits: utils.entropy(utils.softmax(logits)))
    dVdlogit = grad(lambda logits: np.sum(utils.value_functional(
        mdp.P, mdp.r, utils.softmax(logits), mdp.discount)))

    @jit
    def update_fn(logits):
        # NOTE this is actually soft A2C.
        # V = utils.value_functional(mdp.P, mdp.r, utils.softmax(logits), mdp.discount)
        # Q = utils.bellman_operator(mdp.P, mdp.r, V, mdp.discount)
        # A = Q - V
        # g = np.einsum('ijkl,ij->kl', dlogpi_dlogit(logits), A)
        g = dVdlogit(logits)
        return logits + lr * utils.clip_by_norm(g, 500) + 1e-8 * dHdlogit(logits)
    return update_fn
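# Sanity check (a sketch, assuming the utils functions are jax-traceable, as they
# are used with grad/jacrev above): compare the jax gradient of V w.r.t. the
# logits against a central finite-difference estimate on a tiny random MDP.
import numpy
import numpy.random as rnd
import mdp.utils as utils
from jax import grad
import jax.numpy as np

mdp = utils.build_random_mdp(2, 2, 0.9)
f = lambda logits: np.sum(utils.value_functional(
    mdp.P, mdp.r, utils.softmax(logits), mdp.discount))
logits = rnd.standard_normal((2, 2))
g = grad(f)(logits)

eps, fd = 1e-4, numpy.zeros((2, 2))
for i in range(2):
    for j in range(2):
        d = numpy.zeros((2, 2)); d[i, j] = eps
        fd[i, j] = (f(logits + d) - f(logits - d)) / (2 * eps)
print(numpy.allclose(g, fd, atol=1e-3))  # the two estimates should agree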
def generate_model_cs():
    """
    Compare using all deterministic policies versus fewer mixed policies.
    Starts to get interesting in higher dims?
    """
    n_states = 32
    n_actions = 2
    lr = 0.01
    k = 64

    mdp = utils.build_random_mdp(n_states, n_actions, 0.5)
    init = rnd.standard_normal((mdp.S * mdp.S * mdp.A + mdp.S * mdp.A))
    pi_star = utils.solve(policy_iteration(mdp),
                          utils.softmax(rnd.standard_normal((mdp.S, mdp.A))))[-1]
    print('pi_star\n', pi_star)

    # adversarial pis
    # apis = utils.get_deterministic_policies(mdp.S, mdp.A)
    apis = np.stack([utils.random_det_policy(mdp.S, mdp.A) for _ in range(k)])

    update_fn = model_iteration(mdp, lr, apis)
    params = utils.solve(update_fn, init)
    p_logits, r = parse_model_params(mdp.S, mdp.A, params[-1])

    error = np.mean(
        (utils.value_functional(mdp.P, mdp.r, pi_star, mdp.discount)
         - utils.value_functional(utils.softmax(p_logits), r, pi_star, mdp.discount))**2)
    print('\n', error)

    new_mdp = utils.MDP(mdp.S, mdp.A, utils.softmax(p_logits), r, mdp.discount, mdp.d0)
    pi_star = utils.solve(policy_iteration(new_mdp),
                          utils.softmax(rnd.standard_normal((mdp.S, mdp.A))))[-1]
    print(pi_star)

    apis = np.stack([utils.random_policy(mdp.S, mdp.A) for _ in range(k)])

    update_fn = model_iteration(mdp, lr, apis)
    params = utils.solve(update_fn, init)
    p_logits, r = parse_model_params(mdp.S, mdp.A, params[-1])

    error = np.mean(
        (utils.value_functional(mdp.P, mdp.r, pi_star, mdp.discount)
         - utils.value_functional(utils.softmax(p_logits), r, pi_star, mdp.discount))**2)
    print('\n', error)

    new_mdp = utils.MDP(mdp.S, mdp.A, utils.softmax(p_logits), r, mdp.discount, mdp.d0)
    pi_star = utils.solve(policy_iteration(new_mdp),
                          utils.softmax(rnd.standard_normal((mdp.S, mdp.A))))[-1]
    print(pi_star)
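# Hedged sketches of the two policy samplers compared above; the real versions
# live in utils, these only describe their assumed behaviour:
def random_det_policy_sketch(n_states, n_actions):
    # one one-hot row per state: a uniformly random deterministic policy
    idx = rnd.randint(0, n_actions, size=n_states)
    return np.eye(n_actions)[idx]

def random_policy_sketch(n_states, n_actions):
    # each row is a random (mixed) distribution over actions
    return utils.softmax(rnd.standard_normal((n_states, n_actions)), axis=1)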
def test_density():
    mdp = utils.build_random_mdp(2, 2, 0.9)
    pi = utils.softmax(rnd.standard_normal((2, 2)), axis=1)
    p_V = density_value_functional(0.1, mdp.P, mdp.r, pi, 0.9)
    print(p_V)
def sparse_loss(x):
    a = utils.softmax(x)  # convex combination
    return mse(np.dot(basis, a), b) + gamma * utils.entropy(a)
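# The free variables above (mse, basis, b, gamma) come from the enclosing scope.
# A self-contained sketch of one plausible setup, with hypothetical shapes and
# local stand-ins for the helpers:
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def entropy(a):
    return -np.sum(a * np.log(a + 1e-8))

def mse(x, y):
    return np.mean((x - y)**2)

basis = np.random.standard_normal((4, 8))  # columns = basis value functions
b = np.random.standard_normal((4,))        # target value function
gamma = 1e-2                               # weight on the entropy (sparsity) term

def sparse_loss_sketch(x):
    a = softmax(x)  # convex combination of basis columns
    return mse(np.dot(basis, a), b) + gamma * entropy(a)

print(sparse_loss_sketch(np.zeros(8)))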
def loss_fn(params, pis):
    p_logits, r = parse_model_params(n_states, n_actions, params)
    return np.sum(value(utils.softmax(p_logits), r, pis)**2)
def loss(pi_logits):
    pi = utils.softmax(pi_logits)
    # P_pi(s'|s) = \sum_a pi(a|s) P(s'|s, a)
    P_pi = np.einsum('ijk,jk->ij', P, pi)
    return np.sum(np.multiply(u, np.log(u / P_pi)))  # KL(u || P_pi)
def loss_fn(params):
    p_logits, r = parse_model_params(mdp.S, mdp.A, params)
    return np.sum((V_true(pis) - V_guess(utils.softmax(p_logits), r, pis))**2)
import numpy as np
import numpy.random as rnd

import mdp.utils as utils
from mdp.search_spaces import *


def clip_solver_traj(traj):
    # drop the final iterate if the solver appended a duplicate of the previous one
    if np.isclose(traj[-1], traj[-2], 1e-8).all():
        return traj[:-1]
    else:
        return traj


mdp = utils.build_random_mdp(2, 2, 0.5)
init = utils.softmax(rnd.standard_normal((mdp.S, mdp.A)), axis=1)
pi_traj = clip_solver_traj(utils.solve(policy_iteration(mdp), init))
print(pi_traj)
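# `utils.solve` is used throughout; a hedged sketch of its assumed behaviour for
# array-valued parameters -- iterate the update function to (approximate)
# convergence and return the whole trajectory of iterates, which is why [-1]
# picks out the fixed point elsewhere in this file:
def solve_sketch(update_fn, init, atol=1e-8, max_iters=10000):
    traj = [init]
    for _ in range(max_iters):
        traj.append(update_fn(traj[-1]))
        if np.isclose(traj[-1], traj[-2], atol=atol).all():
            break
    return traj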
def loss(option_logits):
    options = utils.softmax(option_logits)
    # P_pi(s'|s) = \sum_w pi(w|s) P(s'|s, w)
    P_pi = np.einsum('ijk,jk->ij', P_options, options)
    return np.sum(np.multiply(u, np.log(u / P_pi)))  # KL(u || P_pi)
def simple_test():
    """
    Explore the unconstrained dynamics in a simple setting.
    """
    # What about when p(s'|s) = 0 is not possible under the true dynamics?!
    r = np.array([
        [1, 0],
        [0, 0]
    ])

    # Indexed by [s' x s x a]; ensure we have a distribution over s'.
    p000 = 1
    p100 = 1 - p000
    p001 = 0
    p101 = 1 - p001
    p010 = 0
    p110 = 1 - p010
    p011 = 1
    p111 = 1 - p011
    P = np.array([
        [[p000, p001], [p010, p011]],
        [[p100, p101], [p110, p111]],
    ])

    # BUG ??? only seems to work for deterministic transitions!?
    # Oh, this is because deterministic transitions satisfy the row rank requirement??!
    # P = np.random.random((2, 2, 2))
    # P = P/np.sum(P, axis=0)

    # a distribution over future states
    assert np.isclose(np.sum(P, axis=0), np.ones((2, 2))).all()

    pi = utils.softmax(r, axis=1)  # exp Q vals with gamma = 0
    # a distribution over actions
    assert np.isclose(np.sum(pi, axis=1), np.ones((2, ))).all()

    p, q = mdp_encoder(P, r)
    print('q', q)
    print('p', p)
    print('P', P)

    P_pi = np.einsum('ijk,jk->ij', P, pi)
    print('P_pi', P_pi)

    # The unconstrained dynamics with deterministic transitions are the same
    # as using gamma = 0 Boltzmann Q values.
    print("exp(r) is close to p? {}".format(np.isclose(p, P_pi, atol=1e-4).all()))

    # r(s, a) = q(s) - KL(P(.|s, a) || p(.|s))
    ce = numpy.zeros((2, 2))
    for j in range(2):      # states
        for k in range(2):  # actions
            ce[j, k] = CE(P[:, j, k], p[:, j])
    r_approx = q[:, np.newaxis] + ce

    print(np.around(r, 3))
    print(np.around(r_approx, 3))
    print('r ~= q - CE(P || p): {}'.format(np.isclose(r, r_approx, atol=1e-2).all()))
    print('\n\n')
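# `CE` is not defined in this snippet. A hedged sketch of one reading consistent
# with the check r_approx = q + ce above: take CE(P, p) = \sum_{s'} P(s') log p(s')
# (the negative of the usual cross-entropy), so that for deterministic P it equals
# -KL(P(.|s,a) || p(.|s)) and r = q - KL becomes r ~= q + ce.
def CE_sketch(P_sa, p_s):
    return np.sum(P_sa * np.log(p_s + 1e-8))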
def update_fn(cores):
    V = utils.value_functional(mdp.P, mdp.r, utils.softmax(build(cores), axis=1), mdp.discount)
    Q = utils.bellman_operator(mdp.P, mdp.r, V, mdp.discount)
    A = Q - V
    grads = [np.einsum('ijkl,ij->kl', d, A) for d in dlogpi_dw(cores)]
    return [c + lr * utils.clip_by_norm(g, 100) + 1e-6 * dH
            for c, g, dH in zip(cores, grads, dHdw(cores))]