Exemplo n.º 1
0
def loglik(params):
    """Log-likelihood of behavioural data under what appears to be a hybrid
    model-based/model-free learner with perseveration (two-stage task) --
    TODO confirm against the surrounding project.

    Arguments:

        params: Sequence of 7 parameters: `lr1`, `lr2` (learning rates),
            `B1`, `B2` (inverse softmax temperatures per stage), `td`
            (eligibility-trace weight), `w` (model-based mixing weight),
            and `persev` (perseveration bonus on the last action).

    Returns:

        Scalar log-likelihood `L`.

    NOTE(review): relies on `data`, `env`, `fu`, and `np` from the
    enclosing scope -- confirm they are defined before calling.
    """
    lr1, lr2, B1, B2, td, w, persev = params
    X, U, X_, U_, R = data.unpack_tensor(env.nstates,
                                         env.nactions,
                                         get='sasar')
    X = np.squeeze(X)
    U = np.squeeze(U)
    X_ = np.squeeze(X_)
    U_ = np.squeeze(U_)
    R = R.flatten()

    # BUGFIX: copy before zeroing. `T = env.T` aliases the environment's
    # array, so `T[:, 1:, 1:] = 0` silently mutated env.T in place.
    T = env.T.copy()
    T[:, 1:, 1:] = 0
    Qmf = np.zeros((env.nactions, env.nstates))
    Q = np.zeros((env.nactions, env.nstates))
    L = 0
    a_last = np.zeros(env.nactions)
    for t in range(R.size):
        x = X[t]
        u = U[t]
        r = R[t]
        x_ = X_[t]
        u_ = U_[t]
        q = np.einsum('ij,j->i', Q, x)      # stage-1 action values (mixed)
        q_ = np.einsum('ij,j->i', Qmf, x_)  # stage-2 action values (MF)

        logits1 = B1 * q + persev * a_last
        logits2 = B2 * q_
        # Removed dead code: pu1/pu2 softmax values were computed but
        # never used anywhere in the function.
        L += np.dot(u, logits1 - fu.logsumexp(logits1))
        L += np.dot(u_, logits2 - fu.logsumexp(logits2))

        # Eligibility traces; z2 carries the stage-1 trace decayed by td
        z1 = np.outer(u, x)
        z2 = np.outer(u_, x_) + td * z1
        Qmf = Qmf + lr1 * (u_.T @ Qmf @ x_) * z1 - lr1 * (
            u.T @ Qmf @ x) * z1 + lr2 * r * z2 - lr2 * (u_.T @ Qmf @ x_) * z2

        # Model-based values: propagate the max stage-2 value through T,
        # then mix MB and MF values with weight w
        maxQmf = np.max(Qmf, axis=0)
        Qmb = np.einsum('ijk,j->ik', T, maxQmf)
        Q = w * Qmb + (1 - w) * Qmf

        a_last = u
    return L
Exemplo n.º 2
0
    def _log_prob_noderivatives(self, x):
        r""" Computes the log-probability of an action $\mathbf u$ without computing derivatives.

        This is here only to facilitate unit testing of the `.log_prob` method by comparison against `autograd`.
        """
        # Logits are the temperature-scaled values; cache them on the instance
        self.logits = self.inverse_softmax_temp * x

        # Normalize by the log-partition; fall back to 0 if it is not finite
        lse = fu.logsumexp(self.logits)
        return self.logits - (lse if np.isfinite(lse) else 0.)
Exemplo n.º 3
0
    def log_prob(self, x):
        r""" Computes the log-probability of an action $\mathbf u$ under a
        "sticky" softmax policy, caching first- and second-order derivatives
        on the instance.

        $$
        \log p(\mathbf u|\mathbf v, \mathbf u_{t-1}) = \big(\beta \mathbf v + \beta^\rho \mathbf u_{t-1}\big) - \log \sum_{i} e^{\beta v_i + \beta^\rho u_{t-1}^{(i)}}
        $$

        Arguments:

            x: State vector of type `ndarray((nactions,))`

        Returns:

            Vector of action log-probabilities (`logits - logsumexp(logits)`)

        Side effects: populates `self.logits`, `self.hess_logprob`,
        `self.d_logits`, and `self.d_logprob`.
        """
        # Compute logits: value term plus a perseveration bonus on the
        # previously chosen action
        Bx  = self.inverse_softmax_temp*x
        stickiness = self.perseveration*self.a_last
        self.logits = Bx + stickiness

        # Hessians of the log-probability w.r.t. beta, rho, the beta-rho
        # cross term, and the action values x (fifth output unused)
        HB, Hp, HBp, Hx, _ = hess.log_stickysoftmax(self.inverse_softmax_temp,
                                                    self.perseveration,
                                                    x,
                                                    self.a_last)
        self.hess_logprob['inverse_softmax_temp'] = HB
        self.hess_logprob['perseveration'] = Hp
        self.hess_logprob['action_values'] = Hx
        self.hess_logprob['inverse_softmax_temp_perseveration'] = HBp

        # Derivatives
        #  Grad LSE wrt Logits (presumably the softmax probabilities,
        #  tiled -- confirm against grad.logsumexp)
        Dlse = grad.logsumexp(self.logits)

        # Grad logprob wrt logits
        self.d_logprob['logits'] = np.eye(x.size) - Dlse

        #  Partial derivative with respect to inverse softmax temp
        self.d_logits['inverse_softmax_temp'] = x
        self.d_logits['perseveration'] = self.a_last
        self.d_logprob['inverse_softmax_temp'] = x - np.dot(Dlse, x)
        self.d_logprob['perseveration'] = self.a_last - np.dot(Dlse, self.a_last)

        # Gradient with respect to x: chain rule through logits = beta*x
        B = np.eye(x.size)*self.inverse_softmax_temp
        Dlsetile = np.tile(self.inverse_softmax_temp*Dlse, [x.size, 1])
        self.d_logprob['action_values'] = B - Dlsetile

        # Log-probability; guard against a non-finite normalizer
        LSE = fu.logsumexp(self.logits)
        if not np.isfinite(LSE): LSE = 0.
        return self.logits - LSE
Exemplo n.º 4
0
    def log_prob(self, x):
        r""" Computes the log-probability of an action $\mathbf u$, in addition to computing derivatives up to second order

        $$
        \log p(\mathbf u|\mathbf v) = \beta \mathbf v - \log \sum_{v_i} e^{\beta \mathbf v_i}
        $$

        Arguments:

            x: State vector of type `ndarray((nstates,))`

        Returns:

            Vector of action log-probabilities

        Side effects: populates `self.logits`, `self.hess_logprob`,
        `self.d_logits`, and `self.d_logprob`.
        """
        beta = self.inverse_softmax_temp
        n = x.size

        # Logits are the temperature-scaled action values
        self.logits = beta * x

        # Second-order derivatives w.r.t. beta and the action values
        HB, Hx = hess.log_softmax(beta, x)
        self.hess_logprob['inverse_softmax_temp'] = HB
        self.hess_logprob['action_values'] = Hx

        # First-order derivatives, chained through the logits
        dlse = grad.logsumexp(self.logits)  # grad of LSE w.r.t. logits
        eye = np.eye(n)
        self.d_logprob['logits'] = eye - dlse
        self.d_logits['inverse_softmax_temp'] = x
        self.d_logprob['inverse_softmax_temp'] = np.dot(
            self.d_logprob['logits'], self.d_logits['inverse_softmax_temp'])

        # Gradient w.r.t. the action values x (logits = beta * x)
        self.d_logprob['action_values'] = beta * eye - np.tile(beta * dlse,
                                                               [n, 1])

        # Log-probability; guard against a non-finite normalizer
        lse = fu.logsumexp(self.logits)
        return self.logits - (lse if np.isfinite(lse) else 0.)
Exemplo n.º 5
0
    def log_prob(self, x):
        r""" Computes the log-probability of an action $\mathbf u$

        $$
        \log p(\mathbf u|\mathbf v) = \beta \mathbf v - \log \sum_{v_i} e^{\beta \mathbf v_i}
        $$

        Arguments:

            x: State vector of type `ndarray((nstates,))`

        Returns:

            Vector of action log-probabilities
        """
        # Shift by the max for numerical stability; the shift cancels in
        # the normalized log-probabilities
        shifted = x - np.max(x)
        logits = self.inverse_softmax_temp * shifted
        lse = logsumexp(logits)
        # Guard against a non-finite normalizer
        return logits - (lse if np.isfinite(lse) else 0.)
Exemplo n.º 6
0
    def log_prob(self, x):
        r""" Computes the log-probability of an action $\mathbf u$ under a
        "sticky" softmax policy with a perseveration bonus on the last action.

        $$
        \log p(\mathbf u|\mathbf v, \mathbf u_{t-1}) = \big(\beta \mathbf v + \beta^\rho \mathbf u_{t-1}\big) - \log \sum_{i} e^{\beta v_i + \beta^\rho u_{t-1}^{(i)}}
        $$

        Arguments:

            x: State vector of type `ndarray((nstates,))`

        Returns:

            Vector of action log-probabilities
        """
        # BUGFIX: the stickiness term must be the *vector*
        # perseveration * a_last (a bonus on the previous action's logit).
        # The old code used np.inner(a_last, a_last), a scalar added to
        # every logit, which cancels in the softmax and silently disabled
        # perseveration.
        logits = self.inverse_softmax_temp * x + self.perseveration * self.a_last
        logits = logits - np.max(logits)  # shift for numerical stability
        LSE = np.logaddexp.reduce(logits)
        # Guard against a non-finite normalizer
        if not np.isfinite(LSE): LSE = 0.
        return logits - LSE
def f(w):
    """Log-likelihood of observed choices under a TD(lambda)-style softmax agent.

    Arguments:

        w: Parameter vector; from the code, `w[0]` scales the value update
            (learning rate), `w[1]` is the inverse softmax temperature,
            `w[2]` the discount factor, and `w[3]` the trace decay.

    Returns:

        Scalar log-likelihood `L`.

    NOTE(review): relies on `task`, `X`, `U`, `X_`, `U_`, `R`, `DONE`,
    `fu`, and `np` from the enclosing scope -- confirm before calling.
    """
    Q = np.zeros((task.nactions, task.nstates))
    # BUGFIX: initialize the eligibility trace before the loop. Previously
    # `z` was only bound inside the `done == 0` branch, so a first trial
    # with done != 0 raised NameError at the trace-update line.
    z = np.zeros((task.nactions, task.nstates))
    L = 0
    for t in range(R.size):
        x = X[0, t]
        u = U[0, t]
        r = R.flatten()[t]
        x_ = X_[0, t]
        u_ = U_[0, t]
        done = DONE.flatten()[t]

        # Policy: softmax over the current state's action values
        logits = w[1] * np.einsum('ij,j->i', Q, x)
        lp = logits - fu.logsumexp(logits)
        L += np.einsum('i,i->', u, lp)

        # Reset trace (NOTE(review): resets when done == 0 -- this looks
        # inverted relative to the usual terminal-flag convention; confirm
        # the DONE coding against the data pipeline)
        if done == 0:
            z = np.zeros((task.nactions, task.nstates))
        # Update trace
        z = np.outer(u, x) + w[2] * w[3] * z
        # Compute RPE
        rpe = r + w[2] * np.einsum('i,ij,j->', u_, Q, x_) - np.einsum(
            'i,ij,j->', u, Q, x)
        # Update value function
        Q += w[0] * rpe * z
    return L
Exemplo n.º 8
0
def test_logsumexp():
    """Check the local `logsumexp` against scipy's reference implementation."""
    x = np.arange(5)
    # BUGFIX: use a tolerance-based comparison instead of np.equal.
    # Two logsumexp implementations may legitimately differ in the last
    # ulp (different shift/summation order), so exact float equality is
    # a flaky assertion.
    assert np.isclose(logsumexp(x), scipy_logsumexp(x))
Exemplo n.º 9
0
                                           [q_.size, 1])
    Dlogit1_q1 = B1
    Dlogit2_q2 = B2
    Dlogit1_B1 = q
    Dlogit2_B2 = q_
    Dlogit1_persev = a_last
    Dlp_q1 = B1 * np.eye(q.size) - np.tile(B1 * grad.logsumexp(logits1),
                                           [q.size, 1])
    Dlp_q2 = B2 * np.eye(q_.size) - np.tile(B2 * grad.logsumexp(logits2),
                                            [q_.size, 1])
    Dq1_Q = x
    Dq2_Q = x_
    Dq2_Qmf = x_
    DQ_w = Qmb - Qmf

    L += np.dot(u, logits1 - fu.logsumexp(logits1))
    L += np.dot(u_, logits2 - fu.logsumexp(logits2))

    Dlp_lr1 += np.dot(u, np.einsum('ij,jk,k->i', Dlp_q1, DQ_lr1, Dq1_Q))
    Dlp_lr1 += np.dot(u_, np.einsum('ij,jk,k->i', Dlp_q2, DQmf_lr1, Dq2_Qmf))
    Dlp_lr2 += np.dot(u, np.einsum('ij,jk,k->i', Dlp_q1, DQ_lr2, Dq1_Q))
    Dlp_lr2 += np.dot(u_, np.einsum('ij,jk,k->i', Dlp_q2, DQmf_lr2, Dq2_Qmf))
    Dlp_w += np.dot(u, np.dot(Dlp_q1, np.dot(DQ_w, Dq1_Q)))
    Dlp_td += np.dot(
        u,
        np.einsum('ij,jk,k->i', Dlp_q1, (DQ_Qmb * DQmb_Qmf + DQ_Qmf) * DQmf_td,
                  Dq1_Q))
    Dlp_td += np.dot(u_, np.einsum('ij,jk,k->i', Dlp_q2, DQmf_td, Dq2_Qmf))

    Dlp_B1 += u @ Dlp_logit1 @ Dlogit1_B1
    Dlp_persev += u @ Dlp_logit1 @ Dlogit1_persev
    u = U[0, t]
    r = R.flatten()[t]
    x_ = X_[0, t]
    u_ = U_[0, t]
    done = DONE.flatten()[t]

    if done == 0: agent_inv.reset_trace()
    agent_inv.log_prob(x, u)
    agent_inv.learning(x, u, r, x_, u_)

    q = np.einsum('ij,j->i', Q, x)
    Dq_Q = x
    logits = B * q
    Dlogit_B = q
    Dlogit_q = B
    lp = logits - fu.logsumexp(logits)
    Dlp_logit = np.eye(logits.size) - np.tile(
        fu.softmax(logits).flatten(), [logits.size, 1])
    L += np.einsum('i,i->', u, lp)
    pu = fu.softmax(logits)
    du = u - pu
    dpu_dlogit = grad.softmax(logits)
    DQ_lr_state = np.dot(DQ_lr, x)
    DQ_dc_state = np.dot(DQ_dc, x)
    DQ_td_state = np.dot(DQ_td, x)
    dpu_lr = B * np.einsum('ij,j->i', dpu_dlogit, DQ_lr_state)
    dpu_B = np.einsum('ij,j->i', dpu_dlogit, Dlogit_B)
    dpu_dc = B * np.einsum('ij,j->i', dpu_dlogit, DQ_dc_state)
    dpu_td = B * np.einsum('ij,j->i', dpu_dlogit, DQ_td_state)

    HB, Hq = hess.log_softmax(B, q)