Beispiel #1
0
def test_grad_twostep_stickysoftmax_bellmanmax_agent():
    """Compare fitr's analytic log-likelihood gradient with autograd's.

    A TwoStepStickySoftmaxSARSABellmanMaxAgent is simulated on a seeded
    DawTwoStep task for 20 trials. The recorded trials are then replayed
    through agents built from one parameter vector: one accumulates
    ``grad_`` analytically via ``log_prob``/``learning``, the other uses the
    ``_noderivatives`` methods and is differentiated with autograd's
    ``jacobian``. The two negative-log-likelihood gradients must agree to
    within 1e-6.
    """
    n_trials = 20

    # Simulate behaviour; each trial is stored as (x, u, x_, u_, r)
    task = DawTwoStep(rng=np.random.RandomState(23))
    simulator = TwoStepStickySoftmaxSARSABellmanMaxAgent(
        task, 0.1, 0.2, 2., 1.5, 1., 0.5, 0.01, rng=np.random.RandomState(743))
    history = []
    for _trial in range(n_trials):
        state1 = task.observation()
        action1 = simulator.action_step1(state1)
        state2, _, _ = task.step(action1)
        action2 = simulator.action_step2(state2)
        _, reward, _ = task.step(action2)
        simulator.learning(state1, action1, state2, action2, reward)
        history.append((state1, action1, state2, action2, reward))

    def build_agent(params):
        # Fresh agent on an unseeded task; keywords fix the meaning of each
        # entry of the parameter vector.
        return TwoStepStickySoftmaxSARSABellmanMaxAgent(
            DawTwoStep(),
            learning_rate_1=params[0],
            learning_rate_2=params[1],
            inverse_softmax_temp_1=params[2],
            inverse_softmax_temp_2=params[3],
            trace_decay=params[4],
            mb_weight=params[5],
            perseveration=params[6])

    def neg_loglik(params):
        """Negative log-likelihood of ``history`` via the derivative-free path."""
        model = build_agent(params)
        for state1, action1, state2, action2, reward in history:
            model._log_prob_noderivatives(state1, action1, state2, action2)
            model._learning_noderivatives(state1, action1, state2, action2,
                                          reward)
        return -model.logprob_

    # Analytic gradient accumulated by fitr while replaying the same trials
    params = np.array([0.1, 0.2, 2., 1.5, 1., 0.5, 0.01])
    fitr_agent = build_agent(params)
    for state1, action1, state2, action2, reward in history:
        fitr_agent.log_prob(state1, action1, state2, action2)
        fitr_agent.learning(state1, action1, state2, action2, reward)

    J_fitr = -fitr_agent.grad_
    J_ag = jacobian(neg_loglik)(params)
    assert np.linalg.norm(J_ag - J_fitr) < 1e-6
Beispiel #2
0
    def f(w):
        """Return the negative log-likelihood of the recorded trials under ``w``.

        Uses the agent's ``_noderivatives`` methods so autograd can
        differentiate the result with respect to ``w`` (7 entries, mapped to
        the keyword arguments below). Replays the first 20 entries of ``X``,
        ``U``, ``X_``, ``U_`` and ``R`` from the enclosing scope, which is
        not visible in this chunk.
        """
        model = TwoStepStickySoftmaxSARSABellmanMaxAgent(
            DawTwoStep(),
            learning_rate_1=w[0], learning_rate_2=w[1],
            inverse_softmax_temp_1=w[2], inverse_softmax_temp_2=w[3],
            trace_decay=w[4], mb_weight=w[5], perseveration=w[6])
        for t in range(20):
            trial = (X[t], U[t], X_[t], U_[t])
            model._log_prob_noderivatives(*trial)
            model._learning_noderivatives(*trial, R[t])
        return -model.logprob_
Beispiel #3
0
 def agf_et(et):
     """Run a QLearner with trace decay ``et`` over the fixed trials; return Q.

     Uses ``_update_noderivatives`` so autograd can differentiate Q with
     respect to ``et``. ``make_mdp_trials`` and ``ntrials`` come from the
     enclosing scope (not visible here); this variant unpacks six values
     from ``make_mdp_trials`` and passes ``None`` as the next action.
     """
     X1, X2, U1, U2, X3, R = make_mdp_trials()
     q = QLearner(DawTwoStep(),
                  learning_rate=0.1,
                  discount_factor=0.9,
                  trace_decay=et)
     for i in range(ntrials):
         # Reset the eligibility trace each trial. Take its shape from Q
         # instead of hard-coding (2, 5) so the two always agree.
         q.etrace = np.zeros(q.Q.shape)

         # First step: X1 -> X2 after action U1
         x = X1[i]
         u = U1[i]
         x_ = X2[i]
         q._update_noderivatives(x, u, R @ x_, x_, None)

         # Second step: X2 -> X3 after action U2
         x = X2[i]
         u = U2[i]
         x_ = X3[i]
         q._update_noderivatives(x, u, R @ x_, x_, None)
     return q.Q
Beispiel #4
0
 def agf_dc(dc):
     """Run a SARSALearner with discount factor ``dc`` over the fixed trials.

     Returns the final Q table. Uses ``_update_noderivatives`` so autograd
     can differentiate Q with respect to ``dc``. ``make_mdp_trials`` and
     ``ntrials`` come from the enclosing scope (not visible here).
     """
     X1, X2, U1, U2, X3, U3, R = make_mdp_trials()
     learner = SARSALearner(DawTwoStep(),
                            learning_rate=0.1,
                            discount_factor=dc,
                            trace_decay=0.95)
     for trial in range(ntrials):
         # Each trial starts with a cleared eligibility trace
         learner.etrace = np.zeros(learner.Q.shape)
         transitions = ((X1[trial], U1[trial], X2[trial], U2[trial]),
                        (X2[trial], U2[trial], X3[trial], U3[trial]))
         for state, action, next_state, next_action in transitions:
             learner._update_noderivatives(state, action, R @ next_state,
                                           next_state, next_action)
     return learner.Q
Beispiel #5
0
 def f(w):
     """Return the log-likelihood of the fixed trials under parameters ``w``.

     ``w`` holds (learning_rate, inverse_softmax_temp, discount_factor,
     trace_decay). Only ``_noderivatives`` methods are used, so the result
     is autograd-differentiable. ``make_mdp_trials`` and ``ntrials`` come
     from the enclosing scope (not visible here).
     """
     X1, X2, U1, U2, X3, U3, R = make_mdp_trials()
     agent = SARSASoftmaxAgent(DawTwoStep(),
                               learning_rate=w[0],
                               inverse_softmax_temp=w[1],
                               discount_factor=w[2],
                               trace_decay=w[3])
     for trial in range(ntrials):
         agent.reset_trace()
         steps = ((X1[trial], U1[trial], X2[trial], U2[trial]),
                  (X2[trial], U2[trial], X3[trial], U3[trial]))
         for state, action, next_state, next_action in steps:
             agent._log_prob_noderivatives(state, action)
             agent.critic._update_noderivatives(
                 state, action, R @ next_state, next_state, next_action)
     return agent.logprob_
Beispiel #6
0
def test_grad_sarsasoftmaxagent():
    """Check SARSASoftmaxAgent's analytic gradient and Hessian against autograd.

    Seven fixed two-step trials are replayed through one agent using
    ``log_prob``/``learning`` (which accumulate ``grad_`` and ``hess_``).
    The same trials are replayed through a derivative-free path that
    autograd differentiates. Both the gradient and the Hessian of the
    log-likelihood must agree to within 1e-6.
    """
    ntrials = 7

    def make_mdp_trials():
        rng = np.random.RandomState(3256)
        X1 = np.tile(np.array([1., 0., 0., 0., 0.]), [ntrials, 1])
        X2 = rng.multinomial(1, pvals=[0., 0.5, 0.5, 0., 0.], size=ntrials)
        U1 = rng.multinomial(1, pvals=[0.5, 0.5], size=ntrials)
        U2 = rng.multinomial(1, pvals=[0.5, 0.5], size=ntrials)
        X3 = rng.multinomial(1, pvals=[0., 0., 0., 0.5, 0.5], size=ntrials)
        U3 = rng.multinomial(1, pvals=[0.5, 0.5], size=ntrials)
        R = np.array([0., 0., 0., 1., 0.])
        return X1, X2, U1, U2, X3, U3, R

    def trial_transitions(trials, i):
        """Return the two (x, u, r, x_, u_) transitions that make up trial ``i``."""
        X1, X2, U1, U2, X3, U3, R = trials
        return ((X1[i], U1[i], R @ X2[i], X2[i], U2[i]),
                (X2[i], U2[i], R @ X3[i], X3[i], U3[i]))

    def make_agent(w):
        """Build a SARSASoftmaxAgent from the 4-element parameter vector ``w``."""
        return SARSASoftmaxAgent(DawTwoStep(),
                                 learning_rate=w[0],
                                 inverse_softmax_temp=w[1],
                                 discount_factor=w[2],
                                 trace_decay=w[3])

    w = np.array([0.1, 2., 0.9, 0.95])

    # GRADIENTS WITH FITR
    trials = make_mdp_trials()
    q = make_agent(w)
    for i in range(ntrials):
        q.reset_trace()
        for x, u, r, x_, u_ in trial_transitions(trials, i):
            q.log_prob(x, u)
            q.learning(x, u, r, x_, u_)

    # AUTOGRAD
    def f(w):
        trials = make_mdp_trials()
        q = make_agent(w)
        for i in range(ntrials):
            q.reset_trace()
            for x, u, r, x_, u_ in trial_transitions(trials, i):
                q._log_prob_noderivatives(x, u)
                q.critic._update_noderivatives(x, u, r, x_, u_)
        return q.logprob_

    ag = jacobian(f)(w)
    aH = hessian(f)(w)
    assert np.linalg.norm(q.grad_ - ag) < 1e-6
    assert np.linalg.norm(q.hess_ - aH) < 1e-6
Beispiel #7
0
def test_grad_sarsalearnerupdate():
    """Check SARSALearner's analytic dQ against autograd.

    The learner is run over seven fixed two-step trials with ``q.update``,
    which tracks ``q.dQ``. The test first checks that this path and the
    derivative-free ``_update_noderivatives`` path produce identical Q
    tables. It then compares each partial derivative of Q (learning rate,
    discount factor, trace decay) with autograd's jacobian of the
    derivative-free path.
    """
    ntrials = 7
    learning_rate = 0.1
    discount_factor = 0.9
    trace_decay = 0.95

    def make_mdp_trials():
        rng = np.random.RandomState(3256)
        X1 = np.tile(np.array([1., 0., 0., 0., 0.]), [ntrials, 1])
        X2 = rng.multinomial(1, pvals=[0., 0.5, 0.5, 0., 0.], size=ntrials)
        U1 = rng.multinomial(1, pvals=[0.5, 0.5], size=ntrials)
        U2 = rng.multinomial(1, pvals=[0.5, 0.5], size=ntrials)
        X3 = rng.multinomial(1, pvals=[0., 0., 0., 0.5, 0.5], size=ntrials)
        U3 = rng.multinomial(1, pvals=[0.5, 0.5], size=ntrials)
        R = np.array([0., 0., 0., 1., 0.])
        return X1, X2, U1, U2, X3, U3, R

    def run_learner(lr, dc, et, with_derivatives):
        """Run a SARSALearner over the fixed trials and return it.

        ``with_derivatives`` selects ``q.update``, which tracks ``q.dQ``, or
        ``q._update_noderivatives``, which autograd can trace.
        """
        X1, X2, U1, U2, X3, U3, R = make_mdp_trials()
        q = SARSALearner(DawTwoStep(),
                         learning_rate=lr,
                         discount_factor=dc,
                         trace_decay=et)
        update = q.update if with_derivatives else q._update_noderivatives
        for i in range(ntrials):
            q.etrace = np.zeros(q.Q.shape)
            update(X1[i], U1[i], R @ X2[i], X2[i], U2[i])
            update(X2[i], U2[i], R @ X3[i], X3[i], U3[i])
        return q

    # GRADIENTS WITH FITR
    q = run_learner(learning_rate, discount_factor, trace_decay,
                    with_derivatives=True)

    # AUTOGRAD: Q as a function of a single parameter, others held fixed
    def agf_lr(lr):
        return run_learner(lr, discount_factor, trace_decay, False).Q

    def agf_dc(dc):
        return run_learner(learning_rate, dc, trace_decay, False).Q

    def agf_et(et):
        return run_learner(learning_rate, discount_factor, et, False).Q

    # Ensure all are producing same value functions. Exact equality is
    # transitive, so comparing against the first Q is equivalent to the
    # all-pairs check. Using a plain all() also avoids passing a generator
    # to np.stack, which NumPy has deprecated.
    qlist = [agf_lr(learning_rate), agf_dc(discount_factor),
             agf_et(trace_decay), q.Q]
    assert all(np.all(np.equal(qlist[0], other)) for other in qlist[1:])

    # Check partial derivative of Q with respect to learning rate
    assert (np.linalg.norm(
        q.dQ['learning_rate'] - jacobian(agf_lr)(learning_rate)) < 1e-6)

    # Check partial derivative of Q with respect to discount factor
    assert (np.linalg.norm(
        q.dQ['discount_factor'] - jacobian(agf_dc)(discount_factor)) < 1e-6)

    # Check partial derivative of Q with respect to trace decay
    assert (np.linalg.norm(
        q.dQ['trace_decay'] - jacobian(agf_et)(trace_decay)) < 1e-6)
Beispiel #8
0
def test_set_seed():
    """Two DawTwoStep tasks seeded with the same value share RNG state.

    Also smoke-tests ``get_graph_depth`` and the plotting helpers on the
    first task; their return values are discarded, not checked.
    """
    seed = 235
    tasks = []
    for _ in range(2):
        env = DawTwoStep()
        env.set_seed(seed)
        tasks.append(env)
    first, second = tasks

    # Compare the key arrays (element 1) of the two RNG state tuples
    keys_first = first.rng.get_state()[1]
    keys_second = second.rng.get_state()[1]
    assert np.all(keys_first == keys_second)

    # Smoke test: the graph depth is computed but not checked
    first.get_graph_depth()

    # Smoke-test each plotting helper, dropping the returned figure each time
    plotters = (
        first.plot_graph,
        first.plot_spectral_properties,
        lambda: first.plot_action_outcome_probabilities(outfile=None),
    )
    for make_plot in plotters:
        fig = make_plot()
        del fig
Beispiel #9
0
import autograd.numpy as np
from autograd import jacobian, hessian, elementwise_grad
import fitr.utils as fu
import fitr.gradients as grad
from fitr.environments import DawTwoStep
from fitr.data import BehaviouralData

ntrials = 201
env = DawTwoStep(rng=np.random.RandomState(436))

data = BehaviouralData(1)

# Parameters of the simulated subject.
# NOTE(review): the order of `par` appears to match (learning_rate_1,
# learning_rate_2, inverse_softmax_temp_1, inverse_softmax_temp_2,
# trace_decay, mb_weight, perseveration) as used elsewhere with the
# two-step sticky-softmax agent — confirm against the consuming code.
lr1 = 0.1
lr2 = 0.1
B1 = 2.
B2 = 2
td = 0.95
w = 0.1
persev = 0.1  # single assignment; only this value ever reached `par`
par = np.array([lr1, lr2, B1, B2, td, w, persev])
rng = np.random.RandomState(32)
T = env.T
Qmf = np.zeros((env.nactions, env.nstates))
Q = np.zeros((env.nactions, env.nstates))

data.add_subject(subject_index=0, parameters=par, subject_meta=[])
a_last = np.zeros(env.nactions)
for t in range(ntrials):
    x = env.observation()
    # Action values for the observed state: q[i] = sum_j Q[i, j] * x[j].
    # NOTE(review): `q` is unused in the visible code; the loop body looks
    # truncated, so the rest of the simulation step appears to be missing.
    q = np.einsum('ij,j->i', Q, x)
import autograd.numpy as np
from autograd import jacobian, hessian
import fitr.utils as fu
import fitr.gradients as grad
import fitr.hessians as hess
from fitr.environments import DawTwoStep
from fitr.agents import SARSASoftmaxAgent

# Shared settings for the data-generating agent and `agent_inv` below;
# `w` stacks the agent parameters into one vector.
ntrials = 20
lr = 0.1  # passed as learning_rate
B = 2.  # passed as inverse_softmax_temp
dc = 0.9  # passed as discount_factor
td = 0.95  # passed as trace_decay
w = np.array([lr, B, dc, td])

# NOTE(review): `task` is not used in the visible code; each agent below
# builds its own DawTwoStep with the same seed (532).
task = DawTwoStep(rng=np.random.RandomState(532))
# Data-generating agent: fixed environment seed (532) and agent seed (236),
# so `data` is reproducible across runs.
agent = SARSASoftmaxAgent(DawTwoStep(rng=np.random.RandomState(532)),
                          learning_rate=lr,
                          inverse_softmax_temp=B,
                          discount_factor=dc,
                          trace_decay=td,
                          rng=np.random.RandomState(236))
data = agent.generate_data(ntrials)

# A second agent with identical parameters and seeds. Its use is not
# visible in this chunk (presumably replaying/fitting `data` — confirm).
agent_inv = SARSASoftmaxAgent(DawTwoStep(rng=np.random.RandomState(532)),
                              learning_rate=lr,
                              inverse_softmax_temp=B,
                              discount_factor=dc,
                              trace_decay=td,
                              rng=np.random.RandomState(236))