def loglik(params):
    """Computes the log-likelihood of two-step task data under a hybrid
    model-free/model-based learner with choice perseveration.

    Arguments:

        params: Sequence of 7 free parameters
            `(lr1, lr2, B1, B2, td, w, persev)`: first- and second-stage
            learning rates, first- and second-stage inverse softmax
            temperatures, eligibility-trace decay, model-based weight,
            and perseveration weight.

    Returns:

        Scalar log-likelihood of the observed first- and second-stage
        choices.
    """
    lr1, lr2, B1, B2, td, w, persev = params

    # Unpack (state, action, next-state, next-action, reward) tensors
    X, U, X_, U_, R = data.unpack_tensor(env.nstates, env.nactions, get='sasar')
    X = np.squeeze(X)
    U = np.squeeze(U)
    X_ = np.squeeze(X_)
    U_ = np.squeeze(U_)
    R = R.flatten()

    # NOTE(review): `env.T` is modified in place here (no copy) — confirm
    # that zeroing these transition entries is meant to persist on `env`.
    T = env.T
    T[:, 1:, 1:] = 0

    Qmf = np.zeros((env.nactions, env.nstates))  # model-free values
    Q = np.zeros((env.nactions, env.nstates))    # hybrid (mixed) values
    L = 0
    a_last = np.zeros(env.nactions)  # previously chosen first-stage action
    for t in range(R.size):
        x = X[t]
        u = U[t]
        r = R[t]
        x_ = X_[t]
        u_ = U_[t]

        # First-stage logits: hybrid values plus a stickiness bonus toward
        # the previous first-stage action; second stage uses model-free values.
        q = np.einsum('ij,j->i', Q, x)
        q_ = np.einsum('ij,j->i', Qmf, x_)
        logits1 = B1 * q + persev * a_last
        logits2 = B2 * q_

        # Log-probability of the observed choices at both stages
        # (the softmax probabilities themselves were computed but unused
        # in the original; removed here)
        L += np.dot(u, logits1 - fu.logsumexp(logits1))
        L += np.dot(u_, logits2 - fu.logsumexp(logits2))

        # Eligibility traces for the two stages
        z1 = np.outer(u, x)
        z2 = np.outer(u_, x_) + td * z1

        # Model-free (SARSA(lambda)-style) update
        Qmf = Qmf + lr1 * (u_.T @ Qmf @ x_) * z1 - lr1 * (
            u.T @ Qmf @ x) * z1 + lr2 * r * z2 - lr2 * (u_.T @ Qmf @ x_) * z2

        # Model-based values: expected best second-stage value under T,
        # then mix model-based and model-free values by weight w.
        maxQmf = np.max(Qmf, axis=0)
        Qmb = np.einsum('ijk,j->ik', T, maxQmf)
        Q = w * Qmb + (1 - w) * Qmf
        a_last = u
    return L
def _log_prob_noderivatives(self, x):
    """Evaluates the log-probability of action $\mathbf u$ without any
    derivative computations.

    Exists only so that unit tests can compare the `.log_prob` method
    against `autograd`.
    """
    # Scale the action values by the inverse softmax temperature
    self.logits = x * self.inverse_softmax_temp

    # Normalize; fall back to 0 if the normalizer is not finite
    normalizer = fu.logsumexp(self.logits)
    if not np.isfinite(normalizer):
        normalizer = 0.
    return self.logits - normalizer
def log_prob(self, x): """ Computes the log-probability of an action $\mathbf u$ $$ \log p(\mathbf u|\mathbf v, \mathbf u_{t-1}) = \\big(\\beta \mathbf v + \\beta^\\rho \mathbf u_{t-1}) - \log \sum_{v_i} e^{\\beta \mathbf v_i + \\beta^\\rho u_{t-1}^{(i)}} $$ Arguments: x: State vector of type `ndarray((nactions,))` Returns: Scalar log-probability """ # Compute logits Bx = self.inverse_softmax_temp*x stickiness = self.perseveration*self.a_last self.logits = Bx + stickiness # Hessians HB, Hp, HBp, Hx, _ = hess.log_stickysoftmax(self.inverse_softmax_temp, self.perseveration, x, self.a_last) self.hess_logprob['inverse_softmax_temp'] = HB self.hess_logprob['perseveration'] = Hp self.hess_logprob['action_values'] = Hx self.hess_logprob['inverse_softmax_temp_perseveration'] = HBp # Derivatives # Grad LSE wrt Logits Dlse = grad.logsumexp(self.logits) # Grad logprob wrt logits self.d_logprob['logits'] = np.eye(x.size) - Dlse # Partial derivative with respect to inverse softmax temp self.d_logits['inverse_softmax_temp'] = x self.d_logits['perseveration'] = self.a_last self.d_logprob['inverse_softmax_temp'] = x - np.dot(Dlse, x) self.d_logprob['perseveration'] = self.a_last - np.dot(Dlse, self.a_last) # Gradient with respect to x B = np.eye(x.size)*self.inverse_softmax_temp Dlsetile = np.tile(self.inverse_softmax_temp*Dlse, [x.size, 1]) self.d_logprob['action_values'] = B - Dlsetile LSE = fu.logsumexp(self.logits) if not np.isfinite(LSE): LSE = 0. return self.logits - LSE
def log_prob(self, x): """ Computes the log-probability of an action $\mathbf u$, in addition to computing derivatives up to second order $$ \log p(\mathbf u|\mathbf v) = \\beta \mathbf v - \log \sum_{v_i} e^{\\beta \mathbf v_i} $$ Arguments: x: State vector of type `ndarray((nstates,))` Returns: Scalar log-probability """ # Compute logits self.logits = self.inverse_softmax_temp*x # Hessians HB, Hx = hess.log_softmax(self.inverse_softmax_temp, x) self.hess_logprob['inverse_softmax_temp'] = HB self.hess_logprob['action_values'] = Hx # Derivatives # Grad LSE wrt Logits Dlse = grad.logsumexp(self.logits) # Grad logprob wrt logits self.d_logprob['logits'] = np.eye(x.size) - Dlse # Grad logprob wrt inverse softmax temp self.d_logits['inverse_softmax_temp'] = x self.d_logprob['inverse_softmax_temp'] = np.dot(self.d_logprob['logits'], self.d_logits['inverse_softmax_temp']) # Grad logprob wrt action values `x` B = np.eye(x.size)*self.inverse_softmax_temp Dlsetile = np.tile(self.inverse_softmax_temp*Dlse, [x.size, 1]) self.d_logprob['action_values'] = B - Dlsetile # Compute log-probability of actions LSE = fu.logsumexp(self.logits) if not np.isfinite(LSE): LSE = 0. return self.logits - LSE
def log_prob(self, x):
    """ Computes the log-probability of an action $\mathbf u$

    $$
    \log p(\mathbf u|\mathbf v) = \\beta \mathbf v - \log \sum_{v_i} e^{\\beta \mathbf v_i}
    $$

    Arguments:

        x: State vector of type `ndarray((nstates,))`

    Returns:

        Scalar log-probability
    """
    # Shift by the maximum value for numerical stability; the softmax is
    # invariant under a constant shift of its inputs
    shifted = self.inverse_softmax_temp * (x - np.max(x))

    # Normalize; fall back to 0 if the normalizer is not finite
    normalizer = logsumexp(shifted)
    if not np.isfinite(normalizer):
        normalizer = 0.
    return shifted - normalizer
def log_prob(self, x):
    """ Computes the log-probability of an action $\mathbf u$ under a
    "sticky" softmax policy

    $$
    \log p(\mathbf u|\mathbf v, \mathbf u_{t-1}) = \\big(\\beta \mathbf v + \\beta^\\rho \mathbf u_{t-1}) - \log \sum_{v_i} e^{\\beta \mathbf v_i + \\beta^\\rho u_{t-1}^{(i)}}
    $$

    Arguments:

        x: State vector of type `ndarray((nstates,))`

    Returns:

        Scalar log-probability
    """
    Bx = self.inverse_softmax_temp * x
    # Perseveration bonus applies elementwise to the previously chosen
    # action, matching the formula above. (The previous
    # `np.inner(self.a_last, self.a_last)` produced a scalar added to
    # every logit, which cancels in the softmax and silently made the
    # perseveration parameter a no-op.)
    stickiness = self.perseveration * self.a_last
    logits = Bx + stickiness
    logits = logits - np.max(logits)  # shift for numerical stability
    LSE = logsumexp(logits)
    if not np.isfinite(LSE):
        LSE = 0.
    return logits - LSE
def f(w):
    """Computes the log-likelihood of the data under a SARSA(lambda)
    softmax agent.

    Arguments:

        w: Parameter vector; `w[0]` is the learning rate, `w[1]` the
            inverse softmax temperature, `w[2]` the discount factor, and
            `w[3]` the eligibility-trace decay. — TODO confirm against
            caller; inferred from usage below.

    Returns:

        Scalar log-likelihood `L`.
    """
    Q = np.zeros((task.nactions, task.nstates))
    L = 0
    # Initialize the eligibility trace before the first trial so it is
    # always defined. (The original only created `z` inside the
    # `done == 0` branch, raising NameError when the first trial's flag
    # was nonzero.)
    z = np.zeros((task.nactions, task.nstates))
    for t in range(R.size):
        x = X[0, t]
        u = U[0, t]
        r = R.flatten()[t]
        x_ = X_[0, t]
        u_ = U_[0, t]
        done = DONE.flatten()[t]

        # Choice log-probability under a softmax over Q(:, x)
        logits = w[1] * np.einsum('ij,j->i', Q, x)
        lp = logits - fu.logsumexp(logits)
        L += np.einsum('i,i->', u, lp)

        # Reset trace at episode boundaries.
        # NOTE(review): resetting when `done == 0` looks inverted, but it
        # matches the convention used elsewhere in this file — confirm the
        # meaning of the DONE flag.
        if done == 0:
            z = np.zeros((task.nactions, task.nstates))

        # Update trace: decay by gamma * lambda, mark the taken action
        z = np.outer(u, x) + w[2] * w[3] * z

        # Compute the SARSA reward prediction error
        rpe = r + w[2] * np.einsum('i,ij,j->', u_, Q, x_) - np.einsum(
            'i,ij,j->', u, Q, x)

        # Update value function along the eligibility trace
        Q += w[0] * rpe * z
    return L
def test_logsumexp():
    """Checks that our logsumexp agrees with scipy's reference implementation."""
    x = np.arange(5)
    # Compare with a floating-point tolerance: two independent
    # implementations need only agree to machine precision, not
    # bit-for-bit (exact `np.equal` is fragile here).
    assert np.isclose(logsumexp(x), scipy_logsumexp(x))
[q_.size, 1]) Dlogit1_q1 = B1 Dlogit2_q2 = B2 Dlogit1_B1 = q Dlogit2_B2 = q_ Dlogit1_persev = a_last Dlp_q1 = B1 * np.eye(q.size) - np.tile(B1 * grad.logsumexp(logits1), [q.size, 1]) Dlp_q2 = B2 * np.eye(q_.size) - np.tile(B2 * grad.logsumexp(logits2), [q_.size, 1]) Dq1_Q = x Dq2_Q = x_ Dq2_Qmf = x_ DQ_w = Qmb - Qmf L += np.dot(u, logits1 - fu.logsumexp(logits1)) L += np.dot(u_, logits2 - fu.logsumexp(logits2)) Dlp_lr1 += np.dot(u, np.einsum('ij,jk,k->i', Dlp_q1, DQ_lr1, Dq1_Q)) Dlp_lr1 += np.dot(u_, np.einsum('ij,jk,k->i', Dlp_q2, DQmf_lr1, Dq2_Qmf)) Dlp_lr2 += np.dot(u, np.einsum('ij,jk,k->i', Dlp_q1, DQ_lr2, Dq1_Q)) Dlp_lr2 += np.dot(u_, np.einsum('ij,jk,k->i', Dlp_q2, DQmf_lr2, Dq2_Qmf)) Dlp_w += np.dot(u, np.dot(Dlp_q1, np.dot(DQ_w, Dq1_Q))) Dlp_td += np.dot( u, np.einsum('ij,jk,k->i', Dlp_q1, (DQ_Qmb * DQmb_Qmf + DQ_Qmf) * DQmf_td, Dq1_Q)) Dlp_td += np.dot(u_, np.einsum('ij,jk,k->i', Dlp_q2, DQmf_td, Dq2_Qmf)) Dlp_B1 += u @ Dlp_logit1 @ Dlogit1_B1 Dlp_persev += u @ Dlp_logit1 @ Dlogit1_persev
u = U[0, t] r = R.flatten()[t] x_ = X_[0, t] u_ = U_[0, t] done = DONE.flatten()[t] if done == 0: agent_inv.reset_trace() agent_inv.log_prob(x, u) agent_inv.learning(x, u, r, x_, u_) q = np.einsum('ij,j->i', Q, x) Dq_Q = x logits = B * q Dlogit_B = q Dlogit_q = B lp = logits - fu.logsumexp(logits) Dlp_logit = np.eye(logits.size) - np.tile( fu.softmax(logits).flatten(), [logits.size, 1]) L += np.einsum('i,i->', u, lp) pu = fu.softmax(logits) du = u - pu dpu_dlogit = grad.softmax(logits) DQ_lr_state = np.dot(DQ_lr, x) DQ_dc_state = np.dot(DQ_dc, x) DQ_td_state = np.dot(DQ_td, x) dpu_lr = B * np.einsum('ij,j->i', dpu_dlogit, DQ_lr_state) dpu_B = np.einsum('ij,j->i', dpu_dlogit, Dlogit_B) dpu_dc = B * np.einsum('ij,j->i', dpu_dlogit, DQ_dc_state) dpu_td = B * np.einsum('ij,j->i', dpu_dlogit, DQ_td_state) HB, Hq = hess.log_softmax(B, q)