def test_grad_twostep_stickysoftmax_bellmanmax_agent():
    # Generate some data
    task = DawTwoStep(rng=np.random.RandomState(23))
    agent = TwoStepStickySoftmaxSARSABellmanMaxAgent(
        task, 0.1, 0.2, 2., 1.5, 1., 0.5, 0.01,
        rng=np.random.RandomState(743))
    X = []
    U = []
    X_ = []
    U_ = []
    R = []
    for t in range(20):
        x = task.observation()
        u = agent.action_step1(x)
        x_, _, _ = task.step(u)
        u_ = agent.action_step2(x_)
        _, r, _ = task.step(u_)
        agent.learning(x, u, x_, u_, r)
        X.append(x)
        U.append(u)
        X_.append(x_)
        U_.append(u_)
        R.append(r)

    # Define function for testing with autograd
    def f(w):
        agent = TwoStepStickySoftmaxSARSABellmanMaxAgent(
            DawTwoStep(),
            learning_rate_1=w[0],
            learning_rate_2=w[1],
            inverse_softmax_temp_1=w[2],
            inverse_softmax_temp_2=w[3],
            trace_decay=w[4],
            mb_weight=w[5],
            perseveration=w[6])
        for t in range(20):
            agent._log_prob_noderivatives(X[t], U[t], X_[t], U_[t])
            agent._learning_noderivatives(X[t], U[t], X_[t], U_[t], R[t])
        return -agent.logprob_

    # Compute gradient with fitr
    w = np.array([0.1, 0.2, 2., 1.5, 1., 0.5, 0.01])
    agent = TwoStepStickySoftmaxSARSABellmanMaxAgent(
        DawTwoStep(),
        learning_rate_1=w[0],
        learning_rate_2=w[1],
        inverse_softmax_temp_1=w[2],
        inverse_softmax_temp_2=w[3],
        trace_decay=w[4],
        mb_weight=w[5],
        perseveration=w[6])
    for t in range(20):
        agent.log_prob(X[t], U[t], X_[t], U_[t])
        agent.learning(X[t], U[t], X_[t], U_[t], R[t])

    # Check that the gradients are the same
    J_fitr = -agent.grad_
    J_ag = jacobian(f)(w)
    assert (np.linalg.norm(J_ag - J_fitr) < 1e-6)
def agf_et(et):
    X1, X2, U1, U2, X3, R = make_mdp_trials()
    q = QLearner(DawTwoStep(), learning_rate=0.1, discount_factor=0.9,
                 trace_decay=et)
    for i in range(ntrials):
        q.etrace = np.zeros((2, 5))
        x = X1[i]
        u = U1[i]
        x_ = X2[i]
        r = R @ x_
        q._update_noderivatives(x, u, r, x_, None)
        u_ = U2[i]
        x = x_
        u = u_
        x_ = X3[i]
        r = R @ x_
        q._update_noderivatives(x, u, r, x_, None)
    return q.Q
def test_grad_sarsasoftmaxagent():
    ntrials = 7

    def make_mdp_trials():
        rng = np.random.RandomState(3256)
        X1 = np.tile(np.array([1., 0., 0., 0., 0.]), [ntrials, 1])
        X2 = rng.multinomial(1, pvals=[0., 0.5, 0.5, 0., 0.], size=ntrials)
        U1 = rng.multinomial(1, pvals=[0.5, 0.5], size=ntrials)
        U2 = rng.multinomial(1, pvals=[0.5, 0.5], size=ntrials)
        X3 = rng.multinomial(1, pvals=[0., 0., 0., 0.5, 0.5], size=ntrials)
        U3 = rng.multinomial(1, pvals=[0.5, 0.5], size=ntrials)
        R = np.array([0., 0., 0., 1., 0.])
        return X1, X2, U1, U2, X3, U3, R

    w = np.array([0.1, 2., 0.9, 0.95])

    # GRADIENTS WITH FITR
    X1, X2, U1, U2, X3, U3, R = make_mdp_trials()
    q = SARSASoftmaxAgent(DawTwoStep(),
                          learning_rate=w[0],
                          inverse_softmax_temp=w[1],
                          discount_factor=w[2],
                          trace_decay=w[3])
    for i in range(ntrials):
        x = X1[i]
        u = U1[i]
        x_ = X2[i]
        r = R @ x_
        u_ = U2[i]
        q.reset_trace()
        q.log_prob(x, u)
        q.learning(x, u, r, x_, u_)
        x = x_
        u = u_
        x_ = X3[i]
        u_ = U3[i]
        r = R @ x_
        q.log_prob(x, u)
        q.learning(x, u, r, x_, u_)

    # AUTOGRAD
    def f(w):
        X1, X2, U1, U2, X3, U3, R = make_mdp_trials()
        q = SARSASoftmaxAgent(DawTwoStep(),
                              learning_rate=w[0],
                              inverse_softmax_temp=w[1],
                              discount_factor=w[2],
                              trace_decay=w[3])
        for i in range(ntrials):
            q.reset_trace()
            x = X1[i]
            u = U1[i]
            x_ = X2[i]
            r = R @ x_
            u_ = U2[i]
            q._log_prob_noderivatives(x, u)
            q.critic._update_noderivatives(x, u, r, x_, u_)
            x = x_
            u = u_
            x_ = X3[i]
            u_ = U3[i]
            r = R @ x_
            q._log_prob_noderivatives(x, u)
            q.critic._update_noderivatives(x, u, r, x_, u_)
        return q.logprob_

    ag = jacobian(f)(w)
    aH = hessian(f)(w)
    assert (np.linalg.norm(q.grad_ - ag) < 1e-6)
    assert (np.linalg.norm(q.hess_ - aH) < 1e-6)
def test_grad_sarsalearnerupdate():
    ntrials = 7

    def make_mdp_trials():
        rng = np.random.RandomState(3256)
        X1 = np.tile(np.array([1., 0., 0., 0., 0.]), [ntrials, 1])
        X2 = rng.multinomial(1, pvals=[0., 0.5, 0.5, 0., 0.], size=ntrials)
        U1 = rng.multinomial(1, pvals=[0.5, 0.5], size=ntrials)
        U2 = rng.multinomial(1, pvals=[0.5, 0.5], size=ntrials)
        X3 = rng.multinomial(1, pvals=[0., 0., 0., 0.5, 0.5], size=ntrials)
        U3 = rng.multinomial(1, pvals=[0.5, 0.5], size=ntrials)
        R = np.array([0., 0., 0., 1., 0.])
        return X1, X2, U1, U2, X3, U3, R

    # GRADIENTS WITH FITR
    X1, X2, U1, U2, X3, U3, R = make_mdp_trials()
    q = SARSALearner(DawTwoStep(), learning_rate=0.1, discount_factor=0.9,
                     trace_decay=0.95)
    for i in range(ntrials):
        q.etrace = np.zeros(q.Q.shape)
        x = X1[i]
        u = U1[i]
        x_ = X2[i]
        r = R @ x_
        u_ = U2[i]
        q.update(x, u, r, x_, u_)
        x = x_
        u = u_
        x_ = X3[i]
        u_ = U3[i]
        r = R @ x_
        q.update(x, u, r, x_, u_)

    # AUTOGRAD
    def agf_lr(lr):
        X1, X2, U1, U2, X3, U3, R = make_mdp_trials()
        q = SARSALearner(DawTwoStep(), learning_rate=lr, discount_factor=0.9,
                         trace_decay=0.95)
        for i in range(ntrials):
            q.etrace = np.zeros(q.Q.shape)
            x = X1[i]
            u = U1[i]
            x_ = X2[i]
            r = R @ x_
            u_ = U2[i]
            q._update_noderivatives(x, u, r, x_, u_)
            x = x_
            u = u_
            x_ = X3[i]
            u_ = U3[i]
            r = R @ x_
            q._update_noderivatives(x, u, r, x_, u_)
        return q.Q

    def agf_dc(dc):
        X1, X2, U1, U2, X3, U3, R = make_mdp_trials()
        q = SARSALearner(DawTwoStep(), learning_rate=0.1, discount_factor=dc,
                         trace_decay=0.95)
        for i in range(ntrials):
            q.etrace = np.zeros(q.Q.shape)
            x = X1[i]
            u = U1[i]
            x_ = X2[i]
            r = R @ x_
            u_ = U2[i]
            q._update_noderivatives(x, u, r, x_, u_)
            x = x_
            u = u_
            x_ = X3[i]
            u_ = U3[i]
            r = R @ x_
            q._update_noderivatives(x, u, r, x_, u_)
        return q.Q

    def agf_et(et):
        X1, X2, U1, U2, X3, U3, R = make_mdp_trials()
        q = SARSALearner(DawTwoStep(), learning_rate=0.1, discount_factor=0.9,
                         trace_decay=et)
        for i in range(ntrials):
            q.etrace = np.zeros(q.Q.shape)
            x = X1[i]
            u = U1[i]
            x_ = X2[i]
            r = R @ x_
            u_ = U2[i]
            q._update_noderivatives(x, u, r, x_, u_)
            x = x_
            u = u_
            x_ = X3[i]
            u_ = U3[i]
            r = R @ x_
            q._update_noderivatives(x, u, r, x_, u_)
        return q.Q

    # Ensure all are producing the same value functions
    qlist = [agf_lr(0.1), agf_dc(0.9), agf_et(0.95), q.Q]
    assert (np.all(
        np.stack([np.all(np.equal(a, b)) for a in qlist for b in qlist])))

    # Check partial derivative of Q with respect to learning rate
    assert (np.linalg.norm(q.dQ['learning_rate'] - jacobian(agf_lr)(0.1)) < 1e-6)

    # Check partial derivative of Q with respect to discount factor
    assert (np.linalg.norm(q.dQ['discount_factor'] - jacobian(agf_dc)(0.9)) < 1e-6)

    # Check partial derivative of Q with respect to trace decay
    assert (np.linalg.norm(q.dQ['trace_decay'] - jacobian(agf_et)(0.95)) < 1e-6)
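# Illustrative sketch (not part of the fitr test suite): the tests above compare
# hand-propagated derivatives (q.dQ, q.grad_) against autograd by accumulating
# dQ/dparameter forward alongside each value update. The toy functions below use
# hypothetical names and show that same pattern on a single exponential-moving-
# average update, assuming only autograd's `jacobian` as imported in this module.
def _toy_value(lr):
    # Value after two toy updates v <- v + lr * (r - v)
    v = 0.0
    for r in (1.0, 0.0):
        v = v + lr * (r - v)
    return v


def _toy_value_dlr(lr):
    # Forward-accumulated derivative dv/dlr, updated alongside v
    v, dv = 0.0, 0.0
    for r in (1.0, 0.0):
        dv = dv + (r - v) - lr * dv
        v = v + lr * (r - v)
    return dv


# The hand-accumulated derivative matches autograd's jacobian of the toy value
assert np.abs(_toy_value_dlr(0.1) - jacobian(_toy_value)(0.1)) < 1e-6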
def test_set_seed():
    task = DawTwoStep()
    task.set_seed(235)
    task2 = DawTwoStep()
    task2.set_seed(235)
    state1 = task.rng.get_state()
    state2 = task2.rng.get_state()
    assert (np.all(state1[1] == state2[1]))

    # Get graph depth
    d = task.get_graph_depth()

    # Test figures
    f = task.plot_graph()
    del f
    f = task.plot_spectral_properties()
    del f
    f = task.plot_action_outcome_probabilities(outfile=None)
    del f
import autograd.numpy as np
from autograd import jacobian, hessian, elementwise_grad

import fitr.utils as fu
import fitr.gradients as grad
from fitr.environments import DawTwoStep
from fitr.data import BehaviouralData

ntrials = 201
env = DawTwoStep(rng=np.random.RandomState(436))
data = BehaviouralData(1)

# Parameters: two learning rates, two inverse softmax temperatures,
# trace decay, model-based weight, and perseveration
lr1 = 0.1
lr2 = 0.1
B1 = 2.
B2 = 2.
td = 0.95
w = 0.1
persev = 0.1
par = np.array([lr1, lr2, B1, B2, td, w, persev])

rng = np.random.RandomState(32)
T = env.T
Qmf = np.zeros((env.nactions, env.nstates))
Q = np.zeros((env.nactions, env.nstates))
data.add_subject(subject_index=0, parameters=par, subject_meta=[])
a_last = np.zeros(env.nactions)
for t in range(ntrials):
    x = env.observation()
    q = np.einsum('ij,j->i', Q, x)
import autograd.numpy as np
from autograd import jacobian, hessian

import fitr.utils as fu
import fitr.gradients as grad
import fitr.hessians as hess
from fitr.environments import DawTwoStep
from fitr.agents import SARSASoftmaxAgent

ntrials = 20
lr = 0.1
B = 2.
dc = 0.9
td = 0.95
w = np.array([lr, B, dc, td])

task = DawTwoStep(rng=np.random.RandomState(532))
agent = SARSASoftmaxAgent(DawTwoStep(rng=np.random.RandomState(532)),
                          learning_rate=lr,
                          inverse_softmax_temp=B,
                          discount_factor=dc,
                          trace_decay=td,
                          rng=np.random.RandomState(236))
data = agent.generate_data(ntrials)

agent_inv = SARSASoftmaxAgent(DawTwoStep(rng=np.random.RandomState(532)),
                              learning_rate=lr,
                              inverse_softmax_temp=B,
                              discount_factor=dc,
                              trace_decay=td,
                              rng=np.random.RandomState(236))