Esempio n. 1
0
                         drift_mu=np.zeros(3),
                         drift_sd=1.)


# Generate a synthetic behavioural dataset: RWSoftmaxAgent acting on MyBanditTask.
# NOTE(review): 20 and 200 are presumably the number of subjects and the
# number of trials per subject — confirm against generate_behavioural_data.
data = generate_behavioural_data(
    MyBanditTask, RWSoftmaxAgent, 20, 200
)


def log_prob(w, D):
    """Sum the log-probability of the observed choices under an RW-softmax agent.

    A fresh ``RWSoftmaxAgent`` is built on ``MyBanditTask`` for every call, so
    the result depends only on ``w`` and ``D``.

    Arguments:
        w: parameter vector; ``w[0]`` is passed as ``learning_rate`` and
            ``w[1]`` as ``inverse_softmax_temp``, both without any transform.
        D: 2-D trial array, one row per trial. Columns ``:7`` are the
            observation ``x``, ``7:11`` the action vector ``u`` (presumably
            one-hot), column ``11`` the reward ``r`` and ``12:`` the next
            observation ``x_``.

    Returns:
        The accumulated ``u @ agent.log_prob(x)`` over all trials (a
        log-likelihood, not its negation).

    NOTE(review): ``w`` reaches the agent untransformed, while the caller maps
    the optimum through sigmoid/relu afterwards — confirm this matches how
    this fitr version's mlepar/agent expect parameters.
    """
    agent = RWSoftmaxAgent(task=MyBanditTask(),
                           learning_rate=w[0],
                           inverse_softmax_temp=w[1])
    loglik = 0
    for row in D:
        obs, action, reward, next_obs = row[:7], row[7:11], row[11], row[12:]
        # Score the choice first, then update: each trial is evaluated with
        # the agent state from before that trial's learning step.
        loglik += action @ agent.log_prob(obs)
        agent.learning(obs, action, reward, next_obs, None)
    return loglik


# Maximum-likelihood fit of the 2-parameter model (maxstarts=5).
res = mlepar(log_prob,
             data.tensor,
             2,
             maxstarts=5)
# Map the optimum through sigmoid (learning rate, w[0]) and relu
# (inverse softmax temperature, w[1]).
# NOTE(review): log_prob hands w to the agent untransformed, so applying
# sigmoid/relu here may not match the space that was optimised — confirm.
X = res.transform_xmin([sigmoid, relu])

# Criticism: actual vs. estimated parameters, one figure per parameter.
# data.params column 0 is skipped (presumably a subject index — confirm).
lr_fig = actual_estimate(data.params[:, 1],
                         X[:, 0])
plt.show()
ist_fig = actual_estimate(data.params[:, 2],
                          X[:, 1])
plt.show()
Esempio n. 2
0
def test_actual_estimate():
    """Smoke test: ``actual_estimate`` accepts matching vectors plus axis labels.

    Passes as long as no exception is raised; the current matplotlib figure is
    closed afterwards so the test leaves no open figure behind.
    """
    grid = np.linspace(0, 10, 20)
    fig = actual_estimate(grid, grid, xlabel='x', ylabel='y')
    del fig
    plt.close()
Esempio n. 3
0
# Generate synthetic behaviour of RWSoftmaxAgent on TwoArmedBandit.
# NOTE(review): N and T are defined outside this view; presumably the number
# of subjects and trials per subject.
data = generate_behavioural_data(
    TwoArmedBandit, RWSoftmaxAgent, N, T
)


# Create log-likelihood function
def log_prob(w, D):
    """Return the negative log-likelihood of ``D`` and its gradient w.r.t. ``w``.

    The unconstrained parameters are mapped into model space before the agent
    is built: ``w[0]`` through ``sigmoid`` (learning rate) and ``w[1]`` through
    ``stable_exp`` (inverse softmax temperature), each with the
    ``a_min``/``a_max`` bounds passed below.

    Arguments:
        w: length-2 unconstrained parameter vector.
        D: 2-D trial array, one row per trial. Columns ``:3`` are the
            observation ``x``, ``3:5`` the action vector ``u``, column ``5``
            the reward ``r`` and ``6:`` the next observation ``x_``.

    Returns:
        Tuple ``(-agent.logprob_, -J * agent.grad_)``. ``J`` holds the
        derivatives of the two transforms, so the gradient is with respect to
        the unconstrained ``w``. The product is elementwise because each
        transform acts on one parameter only.
    """
    lr = sigmoid(w[0], a_min=-6, a_max=6)
    ist = stable_exp(w[1], a_min=-10, a_max=10)
    agent = RWSoftmaxAgent(TwoArmedBandit(), lr, ist)
    for t in range(D.shape[0]):
        x = D[t, :3]
        u = D[t, 3:5]
        r = D[t, 5]
        x_ = D[t, 6:]
        # Return value is ignored: the agent presumably accumulates
        # logprob_ and grad_ internally (both are read after the loop).
        agent.log_prob(x, u)
        agent.learning(x, u, r, x_, None)
    # NOTE(review): grad.exp takes no bounds while stable_exp was called with
    # a_min/a_max; if stable_exp clips, this Jacobian is inexact for
    # |w[1]| > 10 (same question for sigmoid outside [-6, 6]) — confirm.
    J = np.array([grad.sigmoid(w[0]), grad.exp(w[1])])
    return -agent.logprob_, -J * agent.grad_


# Fit model. jac=True because log_prob returns (objective, gradient).
res = mlepar(log_prob,
             data.tensor,
             nparams=2,
             maxstarts=5,
             jac=True)
X = res.transform_xmin([sigmoid, stable_exp])
# Keep subjects with a defined (non-NaN) learning-rate estimate and an
# inverse-temperature estimate below 20 (larger values presumably indicate
# degenerate fits).
idx = ~np.isnan(X[:, 0]) & (X[:, 1] < 20)

# Criticism: actual vs. estimated parameters, one figure per parameter.
lr_fig = actual_estimate(data.params[idx, 1], X[idx, 0])
plt.show()
ist_fig = actual_estimate(data.params[idx, 2], X[idx, 1])
plt.show()
                  maxstarts=4,
                  maxstarts_without_improvement=2,
                  init_sd=1,
                  njobs=-1,
                  jac=True,
                  hess=True,
                  method='trust-exact')

res_rwssm = mlepar(f=rwstickysoftmax_loglik,
                   data=data_rwsticky.tensor,
                   nparams=3,
                   minstarts=2,
                   maxstarts=4,
                   maxstarts_without_improvement=2,
                   init_sd=1,
                   njobs=-1,
                   jac=True,
                   hess=True,
                   method='trust-exact')

res.xmin  # notebook-style display; has no effect when run as a script

# Keep only subjects whose estimates contain no NaN (computed once, reused
# for both the estimates and the ground truth so the rows stay aligned).
converged = np.logical_not(np.any(np.isnan(res.xmin), axis=1))
xhat = res.xmin[converged, :]
# Back-transform each subject's estimates. np.stack needs a sequence: passing
# a generator has been deprecated since NumPy 1.16 and is rejected by current
# NumPy, so build a list.
xhat = np.stack([
    fu.transform(xhat[i], [fu.sigmoid, fu.stable_exp]).flatten()
    for i in range(xhat.shape[0])
])
xtrue = data.params[converged, 1:]

# Column 1 is the stable_exp-transformed parameter; estimates >= 20 are
# excluded from its plot (presumably degenerate fits).
keep = xhat[:, 1] < 20
f = actual_estimate(xtrue[:, 0], xhat[:, 0])
f = actual_estimate(xtrue[keep, 1], xhat[keep, 1])
Esempio n. 5
0
        q.learning(X1[t], U1[t], R[t], X2[t], u2)
    L = q.logprob_
    g = q.grad_[[0, 1, 3]]
    H = np.hstack((q.hess_[:, :2], q.hess_[:, -1].reshape(-1, 1)))
    H = np.vstack((H[:2, :], H[-1, :].reshape(1, -1)))
    return -L, -J @ g, -J.T @ H @ J


res = mlepar(f=loglik,
             data=data.tensor,
             nparams=3,
             minstarts=2,
             maxstarts=10,
             maxstarts_without_improvement=2,
             init_sd=1,
             njobs=-1,
             jac=True,
             hess=True,
             method='trust-exact')

# Keep only subjects whose estimates contain no NaN (computed once, reused
# for both the estimates and the ground truth so the rows stay aligned).
converged = np.logical_not(np.any(np.isnan(res.xmin), axis=1))
xhat = res.xmin[converged, :]
# Back-transform each subject's estimates. np.stack needs a sequence: passing
# a generator is rejected by current NumPy, so build a list.
xhat = np.stack([
    fu.transform(xhat[i], [fu.sigmoid, fu.stable_exp, fu.sigmoid]).flatten()
    for i in range(xhat.shape[0])
])
xtrue = data.params[converged, 1:]

# xhat has nparams=3 columns (indices 0-2), matching agent parameters 0, 1
# and 3 (loglik keeps grad_[[0, 1, 3]]). Column 1 is the stable_exp-transformed
# parameter (presumably the inverse temperature); the sigmoid columns cannot
# reach 20, so column 1 is the only meaningful filter.
# NOTE(review): assumes data.params[:, 1:] lists all four agent parameters in
# agent order, so xtrue column 3 pairs with xhat column 2 — confirm against
# the data generator.
keep = xhat[:, 1] < 20
f = actual_estimate(xtrue[:, 0], xhat[:, 0])
f = actual_estimate(xtrue[keep, 1], xhat[keep, 1])
f = actual_estimate(xtrue[keep, 3], xhat[keep, 2])