def generate_model_cs():
    """
    Compare using all deterministic policies versus fewer mixed policies.
    Starts to get interesting in higher dims?
    """
    n_states = 32
    n_actions = 2
    lr = 0.01
    k = 64

    mdp = utils.build_random_mdp(n_states, n_actions, 0.5)
    init = rnd.standard_normal((mdp.S * mdp.S * mdp.A + mdp.S * mdp.A))
    pi_star = utils.solve(policy_iteration(mdp),
                          utils.softmax(rnd.standard_normal((mdp.S, mdp.A))))[-1]
    print('pi_star\n', pi_star)

    # deterministic adversarial pis
    # apis = utils.get_deterministic_policies(mdp.S, mdp.A)
    apis = np.stack([utils.random_det_policy(mdp.S, mdp.A) for _ in range(k)])

    update_fn = model_iteration(mdp, lr, apis)
    params = utils.solve(update_fn, init)
    p_logits, r = parse_model_params(mdp.S, mdp.A, params[-1])
    error = np.mean((utils.value_functional(mdp.P, mdp.r, pi_star, mdp.discount) -
                     utils.value_functional(utils.softmax(p_logits), r, pi_star, mdp.discount))**2)
    print('\n', error)
    new_mdp = utils.MDP(mdp.S, mdp.A, utils.softmax(p_logits), r, mdp.discount, mdp.d0)
    # keep pi_star (the true optimum) fixed so both learnt models are scored
    # against the same policy; the recovered policy gets its own name.
    pi_star_det = utils.solve(policy_iteration(new_mdp),
                              utils.softmax(rnd.standard_normal((mdp.S, mdp.A))))[-1]
    print(pi_star_det)

    # mixed adversarial pis
    apis = np.stack([utils.random_policy(mdp.S, mdp.A) for _ in range(k)])

    update_fn = model_iteration(mdp, lr, apis)
    params = utils.solve(update_fn, init)
    p_logits, r = parse_model_params(mdp.S, mdp.A, params[-1])
    error = np.mean((utils.value_functional(mdp.P, mdp.r, pi_star, mdp.discount) -
                     utils.value_functional(utils.softmax(p_logits), r, pi_star, mdp.discount))**2)
    print('\n', error)
    new_mdp = utils.MDP(mdp.S, mdp.A, utils.softmax(p_logits), r, mdp.discount, mdp.d0)
    pi_star_mix = utils.solve(policy_iteration(new_mdp),
                              utils.softmax(rnd.standard_normal((mdp.S, mdp.A))))[-1]
    print(pi_star_mix)

def generate_model_iteration():
    n_states, n_actions = 2, 2

    mdp = utils.build_random_mdp(n_states, n_actions, 0.5)
    pis = utils.gen_grid_policies(7)

    # needs its own init. alternatively could find init that matches value of other inits?!?
    init = rnd.standard_normal((mdp.S * mdp.S * mdp.A + mdp.S * mdp.A))

    vs = utils.polytope(mdp.P, mdp.r, mdp.discount, pis)

    plt.figure(figsize=(16, 16))
    plt.scatter(vs[:, 0], vs[:, 1], c='b', s=10, alpha=0.75)

    lr = 0.01
    pi_star = utils.solve(policy_iteration(mdp),
                          utils.softmax(rnd.standard_normal((mdp.S, mdp.A))))[-1]

    # adversarial pis
    apis = utils.get_deterministic_policies(mdp.S, mdp.A)
    apis = np.stack(apis)

    update_fn = model_iteration(mdp, lr, apis)
    params = utils.solve(update_fn, init)
    params = [parse_model_params(mdp.S, mdp.A, p) for p in params]

    vs = np.vstack([utils.value_functional(utils.softmax(p_logits), r, pi_star, mdp.discount).T
                    for p_logits, r in params])
    n = vs.shape[0]
    plt.scatter(vs[0, 0], vs[0, 1], c='g', label='PG')
    plt.scatter(vs[1:-1, 0], vs[1:-1, 1], c=range(n - 2), cmap='spring', s=10)
    plt.scatter(vs[-1, 0], vs[-1, 1], c='g', marker='x')

    p_logits, r = params[-1]
    vs = utils.polytope(utils.softmax(p_logits), r, mdp.discount, pis)
    plt.scatter(vs[:, 0], vs[:, 1], c='r', s=10, alpha=0.75)

    plt.title('Model iteration')
    plt.xlabel('Value of state 1')
    plt.ylabel('Value of state 2')
    # plt.show()
    plt.savefig('figs/model_iteration_1.png')

    learned_mdp = utils.MDP(mdp.S, mdp.A, utils.softmax(p_logits), r, mdp.discount, mdp.d0)
    pi_star_approx = utils.solve(policy_iteration(learned_mdp),
                                 utils.softmax(rnd.standard_normal((mdp.S, mdp.A))))[-1]
    print(pi_star_approx, '\n', pi_star)

def graph_PI():
    n_states = 10
    n_actions = 2

    det_pis = utils.get_deterministic_policies(n_states, n_actions)
    print('n pis: {}'.format(len(det_pis)))
    mdp = utils.build_random_sparse_mdp(n_states, n_actions, 0.5)

    A = graph.mdp_topology(det_pis)
    G = nx.from_numpy_array(A)
    pos = nx.spring_layout(G, iterations=200)

    basis = graph.construct_mdp_basis(det_pis, mdp)

    init_pi = utils.softmax(np.random.standard_normal((n_states, n_actions)))
    init_v = utils.value_functional(mdp.P, mdp.r, init_pi, mdp.discount).squeeze()
    a = graph.sparse_coeffs(basis, init_v, lr=0.1)

    pis = utils.solve(search_spaces.policy_iteration(mdp), init_pi)
    print("\n{} policies to vis".format(len(pis)))

    for i, pi in enumerate(pis[:-1]):
        print('Iteration: {}'.format(i))
        v = utils.value_functional(mdp.P, mdp.r, pi, mdp.discount).squeeze()
        a = graph.sparse_coeffs(basis, v, lr=0.1, a_init=a)
        plt.figure(figsize=(16, 16))
        nx.draw(G, pos, node_color=a, node_size=150)
        # plt.show()
        plt.savefig('figs/pi_graphs/{}.png'.format(i))
        plt.close()

def sparse_coeffs(basis, b, gamma=1e-6, lr=1e-1, a_init=None):
    """
    Want x s.t. b ~= basis . softmax(x)

    min_x || basis . softmax(x) - b ||_2^2 + gamma * H(softmax(x))

    Since softmax(x) lies on the simplex an L1 penalty would be constant,
    so the entropy term plays the sparsifying role instead.
    """
    assert basis.shape[0] == b.shape[0]

    def sparse_loss(x):
        a = utils.softmax(x)  # convex combination of the basis vectors
        return mse(np.dot(basis, a), b) + gamma * utils.entropy(a)

    dLdx = grad(sparse_loss)

    @jit
    def update_fn(x):
        return x - lr * dLdx(x)

    if a_init is None:
        init = 1e-3 * rnd.standard_normal((basis.shape[1],))
    else:
        init = a_init

    init = (init, np.zeros_like(init))
    output = utils.solve(search_spaces.momentum_bundler(update_fn, 0.9), init)
    a_s, mom_s = zip(*output)
    return a_s[-1]

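# A quick usage sketch (an illustration, not part of the original experiments;
# it assumes this module's imports: numpy as np, numpy.random as rnd, utils).
# When b is itself one of the basis vectors, the recovered convex combination
# should put most of its mass on that column.
def _check_sparse_coeffs():
    basis = rnd.standard_normal((10, 5))
    b = basis[:, 2]  # the target is exactly the third basis vector
    x = sparse_coeffs(basis, b, gamma=1e-6, lr=1e-1)
    print(utils.softmax(x))  # expect the weight on index 2 to dominate
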
def lmdp_solver(p, q, discount):
    """
    Solves the discounted linear Bellman equation, z = QPz^γ,
    where Q = diag(exp(q)).

    Args:
        p (np.ndarray): [n_states x n_states]. The unconditioned dynamics.
        q (np.ndarray): [n_states x 1]. The state rewards.

    Returns:
        (np.ndarray): [n_states x n_states]. The optimal control.
        (np.ndarray): [n_states x 1]. The value of the optimal policy.
    """
    # BUG doesn't work for large discounts: 0.999.
    # Evaluate: solve z = QPz^γ by fixed-point iteration.
    init = np.ones((p.shape[-1], 1))
    update_fn = lambda z: linear_bellman_operator(p, q, z, discount)
    z = utils.solve(update_fn, init)[-1].squeeze()
    v = np.log(z)

    # Calculate the optimal control.
    # G[z](x) = sum_x' p(x'|x) z(x')
    G = np.einsum('ij,i->j', p, z)
    # u*(x'|x) = p(x'|x) z(x') / G[z](x)
    u = p * z[:, np.newaxis] / G[np.newaxis, :]

    return u, v

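# A minimal sanity check (a sketch under the module's conventions, where
# p[x', x] indexes next state by row and current state by column): whatever
# p and q are, each column of the returned control u is a normalised
# distribution over next states, since u(.|x) = p(.|x) z(.) / G[z](x).
def _check_lmdp_solver():
    n = 4
    p = rnd.random((n, n))
    p = p / p.sum(axis=0, keepdims=True)  # columns are distributions over x'
    q = -rnd.random((n, 1))               # non-positive state rewards
    u, v = lmdp_solver(p, q, discount=0.9)
    assert np.allclose(u.sum(axis=0), 1.0)
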
def lmdp_decoder(u, P, lr=10):
    """
    Given the optimal control dynamics, optimise a softmax parameterisation
    of the policy that yields those same dynamics.
    """
    # NOTE is there a way to solve this using linear equations?!
    # W = log(P_pi)
    #   = sum_a log(P[a]) + log(pi[a])
    # M = log(u)
    # UW - UM = 0
    # U(W - M) = 0, W = M = sum_a log(P[a]) + log(pi[a])
    # 0 = sum_a log(P[a]) + log(pi[a]) - M

    def loss(pi_logits):
        pi = utils.softmax(pi_logits)
        # P_pi(s'|s) = sum_a pi(a|s) p(s'|s, a)
        P_pi = np.einsum('ijk,jk->ij', P, pi)
        return np.sum(np.multiply(u, np.log(u / P_pi)))  # KL(u || P_pi)

    dLdw = jit(grad(loss))

    def update_fn(w):
        return w - lr * dLdw(w)

    init = rnd.standard_normal((P.shape[0], P.shape[-1]))
    pi_star_logits = utils.solve(update_fn, init)[-1]
    return utils.softmax(pi_star_logits)

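# A hedged consistency sketch (names assumed from this module): if u is
# exactly the state-transition matrix induced by some stochastic policy,
# the decoder should recover a policy whose induced dynamics match u.
def _check_lmdp_decoder():
    n_states, n_actions = 3, 2
    mdp = utils.build_random_mdp(n_states, n_actions, 0.9)
    pi = utils.random_policy(n_states, n_actions)
    u = np.einsum('ijk,jk->ij', mdp.P, pi)  # the dynamics to match
    pi_hat = lmdp_decoder(u, mdp.P)
    print(np.einsum('ijk,jk->ij', mdp.P, pi_hat) - u)  # expect ~0
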
def graph_PG():
    # ffmpeg -framerate 10 -start_number 0 -i %d.png -c:v libx264 -r 30 -pix_fmt yuv420p out.mp4
    n_states = 6
    n_actions = 4

    det_pis = utils.get_deterministic_policies(n_states, n_actions)
    print('n pis: {}'.format(len(det_pis)))
    mdp = utils.build_random_mdp(n_states, n_actions, 0.9)

    A = graph.mdp_topology(det_pis)
    G = nx.from_numpy_array(A)
    pos = nx.spring_layout(G, iterations=200)

    basis = graph.construct_mdp_basis(det_pis, mdp)

    init_logits = np.random.standard_normal((n_states, n_actions))
    init_v = utils.value_functional(mdp.P, mdp.r, utils.softmax(init_logits), mdp.discount).squeeze()
    a = graph.sparse_coeffs(basis, init_v, lr=0.1)

    print('\nSolving PG')
    pis = utils.solve(search_spaces.policy_gradient_iteration_logits(mdp, 0.1), init_logits)
    print("\n{} policies to vis".format(len(pis)))
    n = len(pis)
    # pis = pis[::n//100]
    pis = pis[0:20]

    for i, pi in enumerate(pis[:-1]):
        print('Iteration: {}'.format(i))
        # the trajectory contains logits, so map through softmax before evaluating
        v = utils.value_functional(mdp.P, mdp.r, utils.softmax(pi), mdp.discount).squeeze()
        a = graph.sparse_coeffs(basis, v, lr=0.1, a_init=a)
        plt.figure()
        nx.draw(G, pos, node_color=a)
        # plt.show()
        plt.savefig('figs/pg_graphs/{}.png'.format(i))
        plt.close()

def lmdp_option_decoder(u, P, lr=1, k=5):
    """
    Given the optimal control dynamics, optimise a softmax parameterisation
    of a policy over options that yields those same dynamics.
    """
    n_states = P.shape[0]
    n_actions = P.shape[-1]

    # the augmented transition fn. [n_states, n_states, n_options]
    P_options = option_transition_fn(P, k)

    def loss(option_logits):
        options = utils.softmax(option_logits)
        # P_pi(s'|s) = sum_w pi(w|s) p(s'|s, w)
        P_pi = np.einsum('ijk,jk->ij', P_options, options)
        return np.sum(np.multiply(u, np.log(u / P_pi)))  # KL(u || P_pi)

    dLdw = jit(grad(loss))

    def update_fn(w):
        return w - lr * dLdw(w)

    n_options = sum([n_actions**(i + 1) for i in range(k)])
    print('N options: {}'.format(n_options))
    init = rnd.standard_normal((P.shape[0], n_options))
    pi_star_logits = utils.solve(update_fn, init)[-1]
    return utils.softmax(pi_star_logits)

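# Sanity arithmetic for the option count above (illustration only): the
# options here are open-loop action sequences of length 1..k, so
# n_options = sum_{i=1..k} n_actions^i. For example, with n_actions=4 and
# k=5 that is 4 + 16 + 64 + 256 + 1024 = 1364 options.
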
def Q(init, M, f):
    # solve
    V_init = utils.value_functional(M.P, M.r, init, M.discount)
    Q_init = utils.bellman_operator(M.P, M.r, V_init, M.discount)
    Q_star = utils.solve(ss.q_learning(M, 0.01), Q_init)[-1]
    # lift
    return np.dot(f.T, np.max(Q_star, axis=1, keepdims=True))

def generate_vi(mdp, c, lr=0.1):
    init_pi = utils.random_policy(mdp.S, mdp.A)
    init_v = utils.value_functional(mdp.P, mdp.r, init_pi, mdp.discount)
    vs = np.stack(utils.solve(ss.value_iteration(mdp, lr), init_v))[:, :, 0]
    n = vs.shape[0]
    plt.scatter(vs[0, 0], vs[0, 1], c=c, s=30, label='{}'.format(n))
    plt.scatter(vs[1:-1, 0], vs[1:-1, 1], c=range(n - 2), cmap='viridis', s=10)
    plt.scatter(vs[-1, 0], vs[-1, 1], c='m', marker='x')

def VI(init, M, f):
    # solve
    V_init = utils.value_functional(M.P, M.r, init, M.discount)
    V_star = utils.solve(ss.value_iteration(M, 0.01), V_init)[-1]
    # lift
    return np.dot(f.T, V_star)

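# Note on the `lift` step shared by Q and VI above (a reading of the code,
# not stated in the original): f appears to be an [n_abstract x n_ground]
# indicator matrix assigning each ground state to its abstract state, so
# f.T @ V_abstract broadcasts abstract values back to the full state space,
# where they can be compared against the ground-truth values.
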
def generate_pg(mdp, c, lr=0.01):
    init_pi = utils.random_policy(mdp.S, mdp.A)
    init_logit = np.log(init_pi)
    logits = utils.solve(ss.policy_gradient_iteration_logits(mdp, lr), init_logit)
    vs = np.stack([utils.value_functional(mdp.P, mdp.r, utils.softmax(logit), mdp.discount)
                   for logit in logits])[:, :, 0]
    n = vs.shape[0]
    plt.scatter(vs[0, 0], vs[0, 1], c=c, s=30, label='{}'.format(n))
    plt.scatter(vs[1:-1, 0], vs[1:-1, 1], c=range(n - 2), cmap='viridis', s=10)
    plt.scatter(vs[-1, 0], vs[-1, 1], c='m', marker='x')

def value_iteration(mdp, pis, lr):
    trajs = []
    for pi in pis:
        init_V = utils.value_functional(mdp.P, mdp.r, pi, mdp.discount)
        traj = utils.solve(ss.value_iteration(mdp, lr), init_V)
        trajs.append(traj)
    return trajs

def generate_pi(mdp, c):
    init_pi = utils.random_policy(mdp.S, mdp.A)
    pis = utils.solve(ss.policy_iteration(mdp), init_pi)
    vs = np.stack([utils.value_functional(mdp.P, mdp.r, pi, mdp.discount) for pi in pis])[:, :, 0]
    n = vs.shape[0]
    plt.scatter(vs[0, 0], vs[0, 1], c=c, s=30, label='{}'.format(n - 2))
    plt.scatter(vs[1:-1, 0], vs[1:-1, 1], c=range(n - 2), cmap='viridis', s=10)
    plt.scatter(vs[-1, 0], vs[-1, 1], c='m', marker='x')
    for i in range(len(vs) - 2):
        dv = 0.1 * (vs[i + 1, :] - vs[i, :])
        plt.arrow(vs[i, 0], vs[i, 1], dv[0], dv[1], color=c, alpha=0.5, width=0.005)

def value_iteration(mdp, pis):
    lens, pi_stars = [], []
    for pi in pis:
        init_V = utils.value_functional(mdp.P, mdp.r, pi, mdp.discount)
        pi_traj = utils.solve(ss.value_iteration(mdp, 0.01), init_V)
        pi_star = pi_traj[-1]
        pi_stars.append(pi_star)
        lens.append(len(pi_traj))
    return lens, pi_stars

def policy_gradient(mdp, pis):
    lens, pi_stars = [], []
    for pi in pis:
        pi_traj = utils.solve(ss.policy_gradient_iteration_logits(mdp, 0.01), np.log(pi + 1e-8))
        pi_star = pi_traj[-1]
        pi_stars.append(pi_star)
        lens.append(len(pi_traj))
    return lens, pi_stars

def generate_cvi():
    print('\nRunning CVI')
    n_states, n_actions = 2, 2
    mdp = utils.build_random_mdp(n_states, n_actions, 0.5)
    fn = ss.complex_value_iteration(mdp, 0.01)
    Q = rnd.standard_normal((n_states, 1)) + 1j * rnd.standard_normal((n_states, 1))
    results = utils.solve(fn, Q)
    print(results)

def policy_iteration(mdp, pis):
    # pi_star = utils.solve(ss.policy_iteration(mdp), pis[0])[-1]
    lens, pi_stars = [], []
    for pi in pis:
        pi_traj = clip_solver_traj(utils.solve(ss.policy_iteration(mdp), pi))
        pi_star = pi_traj[-1]
        pi_stars.append(pi_star)
        lens.append(len(pi_traj))
    return lens, pi_stars

def approximate(v, cores, lr=1e-2, activation_fn=lambda x: x):
    """
    Fit the parameterised matrix product `cores` to a target `v`.

    cores = random_parameterised_matrix(2, 1, d_hidden=8, n_hidden=4)
    v = rnd.standard_normal((2, 1))
    cores_ = approximate(v, cores)
    print(v, '\n', build(cores_))
    """
    loss = lambda cores: np.sum(np.square(v - activation_fn(build(cores))))
    dl2dc = grad(loss)
    l2_update_fn = lambda cores: [c - lr * g for g, c in zip(dl2dc(cores), cores)]
    init = (cores, [np.zeros_like(c) for c in cores])
    final_variables, momentum_var = utils.solve(momentum_bundler(l2_update_fn, 0.9), init)[-1]
    return final_variables

def mom_value_iteration(mdp, pis):
    lens, pi_stars = [], []
    for pi in pis:
        init_V = utils.value_functional(mdp.P, mdp.r, pi, mdp.discount)
        pi_traj = utils.solve(
            ss.momentum_bundler(ss.value_iteration(mdp, 0.01), 0.9),
            (init_V, np.zeros_like(init_V)))
        pi_star, _ = pi_traj[-1]
        pi_stars.append(pi_star)
        lens.append(len(pi_traj))
    return lens, pi_stars

def policy_gradient(args):
    P, r, discount, d0, pis, lr = args
    mdp = utils.MDP(r.shape[0], r.shape[1], P, r, discount, d0)
    lens, pi_stars = [], []
    for pi in pis:
        pi_traj = utils.solve(ss.policy_gradient_iteration_logits(mdp, lr), np.log(pi + 1e-8))
        pi_star = pi_traj[-1]
        pi_stars.append(pi_star)
        lens.append(len(pi_traj))
    return lens, pi_stars

def param_policy_gradient(mdp, pis):
    lens, pi_stars = [], []
    core_init = ss.random_parameterised_matrix(2, 2, 32, 8)
    for pi in pis:
        core_init = ss.approximate(pi, core_init, activation_fn=utils.softmax)
        pi_traj = utils.solve(
            ss.parameterised_policy_gradient_iteration(mdp, 0.01 / len(core_init)),
            core_init)
        pi_star = pi_traj[-1]
        pi_stars.append(pi_star)
        lens.append(len(pi_traj))
    return lens, pi_stars

def onoffpolicy_abstraction(mdp, pis):
    tol = 0.01
    init = np.random.random((mdp.S, mdp.A))
    init = init / np.sum(init, axis=1, keepdims=True)

    # ### all policy abstraction
    # # n x |S| x |A|
    # Qs = np.stack([utils.bellman_operator(mdp.P, mdp.r, utils.value_functional(mdp.P, mdp.r, pi, mdp.discount), mdp.discount) for pi in pis], axis=0)
    # similar_states = np.sum(np.sum(np.abs(Qs[:, :, None, :] - Qs[:, None, :, :]), axis=3), axis=0)  # |S| x |S|
    # all_idx, all_abstracted_mdp, all_f = abs.build_state_abstraction(similar_states, mdp)

    ### optimal policy abstraction
    pi_star = utils.solve(ss.policy_iteration(mdp), np.log(init))[-1]
    Q_star = utils.bellman_operator(
        mdp.P, mdp.r,
        utils.value_functional(mdp.P, mdp.r, pi_star, mdp.discount),
        mdp.discount)
    # similar_states = np.sum(np.abs(Q_star[:, None, :] - Q_star[None, :, :]), axis=-1)  # |S| x |S|. preserves optimal policy's value (for all actions)
    # similar_states = np.abs(np.max(Q_star[:, None, :], axis=-1) - np.max(Q_star[None, :, :], axis=-1))  # |S| x |S|. preserves optimal action's value

    V = utils.value_functional(mdp.P, mdp.r, init, mdp.discount)
    similar_states = np.abs(V[None, :, :] - V[:, None, :])[:, :, 0]
    optimal_idx, optimal_abstracted_mdp, optimal_f = abs.build_state_abstraction(similar_states, mdp, tol)

    mdps = [mdp, optimal_abstracted_mdp]
    names = ['ground', 'optimal_abstracted_mdp']
    solvers = [abs.Q, abs.SARSA, abs.VI]
    lifts = [np.eye(mdp.S), optimal_f]
    idxs = [range(mdp.S), optimal_idx]

    # if all_f.shape[0] == optimal_f.shape[0]:
    #     raise ValueError('Abstractions are the same so we probs wont see any difference...')

    print('\nAbstraction:', optimal_f.shape)
    truth = abs.PI(init, mdp, np.eye(mdp.S))

    results = []
    for n, M, idx, f in zip(names, mdps, idxs, lifts):
        for solve in solvers:
            err = np.max(np.abs(truth - solve(init[idx, :], M, f)))
            results.append((n, solve.__name__, err))
    return results

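# An illustrative driver for the experiment above (a sketch using only names
# already appearing in this file; sizes and the number of policies are
# arbitrary choices):
def _run_abstraction_experiment():
    mdp = utils.build_random_mdp(8, 2, 0.9)
    pis = [utils.random_policy(8, 2) for _ in range(16)]
    for name, solver_name, err in onoffpolicy_abstraction(mdp, pis):
        print('{} / {}: err = {:.4f}'.format(name, solver_name, err))
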
def mom_param_value_iteration(mdp, pis):
    lens, pi_stars = [], []
    core_init = ss.random_parameterised_matrix(2, 2, 32, 4)
    for pi in pis:
        init_V = utils.value_functional(mdp.P, mdp.r, pi, mdp.discount)
        core_init = ss.approximate(init_V, core_init)
        params = utils.solve(
            ss.momentum_bundler(ss.parameterised_value_iteration(mdp, 0.01 / len(core_init)), 0.8),
            (core_init, [np.zeros_like(c) for c in core_init]))
        pi_star, _ = params[-1]
        pi_stars.append(pi_star)
        lens.append(len(params))
    return lens, pi_stars

def param_value_iteration(mdp, pis):
    # hypothesis: we are going to see some weirdness in the momentum partitions.
    # oscillations will depend on the shape of the polytope?!?
    lens, pi_stars = [], []
    core_init = ss.random_parameterised_matrix(2, 2, 32, 4)
    for pi in pis:
        init_V = utils.value_functional(mdp.P, mdp.r, pi, mdp.discount)
        core_init = ss.approximate(init_V, core_init)
        params = utils.solve(ss.parameterised_value_iteration(mdp, 0.01 / len(core_init)), core_init)
        pi_star = params[-1]
        pi_stars.append(pi_star)
        lens.append(len(params))
    return lens, pi_stars

def mom_policy_gradient(mdp, pis):
    lens, pi_stars = [], []
    for pi in pis:
        try:
            pi_traj = utils.solve(
                ss.momentum_bundler(ss.policy_gradient_iteration_logits(mdp, 0.01), 0.9),
                (np.log(pi + 1e-8), np.zeros_like(pi)))
            pi_star, _ = pi_traj[-1]
            L = len(pi_traj)
        except ValueError:
            pi_star = pis[0]
            L = 10000
        pi_stars.append(pi_star)
        lens.append(L)
    return lens, pi_stars

def compare_mdp_lmdp():
    n_states, n_actions = 2, 2
    mdp = utils.build_random_mdp(n_states, n_actions, 0.9)
    pis = utils.gen_grid_policies(7)
    vs = utils.polytope(mdp.P, mdp.r, mdp.discount, pis)

    plt.figure(figsize=(16, 16))
    plt.scatter(vs[:, 0], vs[:, 1], s=10, alpha=0.75)

    # solve via LMDPs
    p, q = lmdps.mdp_encoder(mdp.P, mdp.r)
    u, v = lmdps.lmdp_solver(p, q, mdp.discount)
    pi_u_star = lmdps.lmdp_decoder(u, mdp.P)
    pi_p = lmdps.lmdp_decoder(p, mdp.P)

    # solve MDP
    init = np.random.standard_normal((n_states, n_actions))
    pi_star = utils.solve(search_spaces.policy_iteration(mdp), init)[-1]
    # pi_star = onehot(np.argmax(qs, axis=1), n_actions)

    # evaluate both policies.
    v_star = utils.value_functional(mdp.P, mdp.r, pi_star, mdp.discount)
    v_u_star = utils.value_functional(mdp.P, mdp.r, pi_u_star, mdp.discount)
    v_p = utils.value_functional(mdp.P, mdp.r, pi_p, mdp.discount)

    plt.scatter(v_star[0, 0], v_star[1, 0], c='m', alpha=0.5, marker='x', label='mdp')
    plt.scatter(v_u_star[0, 0], v_u_star[1, 0], c='g', alpha=0.5, marker='x', label='lmdp')
    plt.scatter(v_p[0, 0], v_p[1, 0], c='k', marker='x', alpha=0.5, label='p')
    plt.legend()
    plt.show()

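# Note (a reading of the code above): pi_u_star decodes the optimal control u
# back into a policy, while pi_p decodes the passive dynamics p themselves;
# the latter acts as a baseline showing where the uncontrolled dynamics land
# in the value polytope relative to the two optima.
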
def compare_acc():
    n_states, n_actions = 2, 2

    lmdp = []
    lmdp_rnd = []
    for _ in range(10):
        mdp = utils.build_random_mdp(n_states, n_actions, 0.5)

        # solve via LMDPs
        p, q = lmdps.mdp_encoder(mdp.P, mdp.r)
        u, v = lmdps.lmdp_solver(p, q, mdp.discount)
        pi_u_star = lmdps.lmdp_decoder(u, mdp.P)

        # solve MDP
        init = np.random.standard_normal((n_states, n_actions))
        pi_star = utils.solve(search_spaces.policy_iteration(mdp), init)[-1]

        # solve via LMDPs, with p set to the random dynamics
        p, q = lmdps.mdp_encoder(mdp.P, mdp.r)
        p = np.einsum('ijk,jk->ij', mdp.P, np.ones((n_states, n_actions)) / n_actions)
        # q = np.max(mdp.r, axis=1, keepdims=True)
        u, v = lmdps.lmdp_solver(p, q, mdp.discount)
        pi_u_star_random = lmdps.lmdp_decoder(u, mdp.P)

        # evaluate the three policies.
        v_star = utils.value_functional(mdp.P, mdp.r, pi_star, mdp.discount)
        v_u_star = utils.value_functional(mdp.P, mdp.r, pi_u_star, mdp.discount)
        v_u_star_random = utils.value_functional(mdp.P, mdp.r, pi_u_star_random, mdp.discount)

        lmdp.append(np.isclose(v_star, v_u_star, 1e-3).all())
        lmdp_rnd.append(np.isclose(v_star, v_u_star_random, 1e-3).all())

    print([np.sum(lmdp), np.sum(lmdp_rnd)])
    plt.bar(range(2), [np.sum(lmdp), np.sum(lmdp_rnd)])
    plt.show()

def mdp_lmdp_optimality():
    n_states, n_actions = 2, 2
    n = 5
    plt.figure(figsize=(8, 16))
    plt.title('Optimal control (LMDP) vs optimal policy (MDP)')
    for i in range(n):
        mdp = utils.build_random_mdp(n_states, n_actions, 0.5)

        # solve via LMDPs
        p, q = lmdps.mdp_encoder(mdp.P, mdp.r)
        u, v = lmdps.lmdp_solver(p, q, mdp.discount)

        # solve the MDP
        init = np.random.standard_normal((n_states, n_actions))
        pi_star = utils.solve(search_spaces.policy_iteration(mdp), init)[-1]
        P_pi_star = np.einsum('ijk,jk->ij', mdp.P, pi_star)

        plt.subplot(n, 2, 2 * i + 1)
        plt.imshow(u)
        plt.subplot(n, 2, 2 * i + 2)
        plt.imshow(P_pi_star)

    plt.savefig('figs/lmdp_mdp_optimal_dynamics.png')
    plt.show()

def find_symmetric_mdp(n_states, n_actions, discount, lr=1e-2):
    """
    Approximately find an MDP with ??? symmetry.
    """
    model_init = rnd.standard_normal(n_states * n_states * n_actions + n_states * n_actions)
    pis = utils.get_deterministic_policies(n_states, n_actions)
    # pis = [utils.random_policy(n_states, n_actions) for _ in range(100)]
    pis = np.stack(pis)

    V = vmap(lambda P, r, pi: utils.value_functional(P, r, pi, discount), in_axes=(None, None, 0))

    def loss_fn(model_params):
        # policy symmetry: values are invariant to flipping each policy
        # along the state axis (axis 1 of the stacked [k, S, A] array).
        P, r = ss.parse_model_params(n_states, n_actions, model_params)
        return np.sum(np.square(
            V(utils.softmax(P), r, pis) - V(utils.softmax(P), r, np.flip(pis, 1))))

    # def loss_fn(model_params):
    #     # value symmetry
    #     P, r = ss.parse_model_params(n_states, n_actions, model_params)
    #     vals = V(utils.softmax(P), r, pis)
    #     n = n_states // 2
    #     return np.sum(np.square(vals[:, :n] - vals[:, n:]))

    dldp = grad(loss_fn)
    update_fn = lambda model: model - lr * dldp(model)
    init = (model_init, np.zeros_like(model_init))
    model_params, momentum_var = utils.solve(ss.momentum_bundler(update_fn, 0.9), init)[-1]

    P, r = ss.parse_model_params(n_states, n_actions, model_params)
    d0 = rnd.random((n_states, 1))
    # the learnt transition params are logits (the loss evaluates softmax(P)),
    # so normalise them before building the MDP.
    return utils.MDP(n_states, n_actions, utils.softmax(P), r, discount, d0)

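# A minimal, assumption-laden check of the recovered symmetry (illustrative
# only; the tolerance and policy choice are arbitrary): on the returned MDP,
# a policy and its state-flipped counterpart should have approximately
# matching values if the optimisation converged.
def _check_symmetry(mdp):
    pi = utils.random_det_policy(mdp.S, mdp.A)
    v1 = utils.value_functional(mdp.P, mdp.r, pi, mdp.discount)
    v2 = utils.value_functional(mdp.P, mdp.r, np.flip(pi, 0), mdp.discount)
    print(np.max(np.abs(v1 - v2)))  # expect ~0
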