Example #1
    def session_likelihood(self, session, params_T, return_trial_data=False):

        # Unpack trial events.
        choices, second_steps, outcomes = ut.CTSO_unpack(session.CTSO, 'CSO')

        # Unpack parameters.
        alpha, iTemp, lambd, W, tlr, D, tdec, A, alr = params_T[:9]
        # alpha: value learning rate, iTemp: softmax inverse temperature, lambd: eligibility trace,
        # W: baseline model-based weighting (logit units), tlr: transition prob. learning rate,
        # D: value forgetting rate, tdec: transition prob. forgetting rate,
        # A: arbitration weight, alr: arbitration learning rate.
        if self.use_kernels: bias, CK, SSK = params_T[-3:]

        # Variables.
        n_trials = len(choices)
        Q_td_f = np.zeros([n_trials + 1 , 2])       # Model free first step action values (low, high).
        Q_td_s = np.zeros([n_trials + 1 , 2])       # Model free second step action values (right, left).
        arb    = np.zeros(n_trials + 1)             # Arbitration variable: running average of unsigned state prediction errors.
        trans_probs = np.zeros([n_trials + 1 , 2])  # Transition probabilities for low and high pokes.
        trans_probs[0,:] = 0.5  # Initialize first trial transition probabilities.

        for i, (c, s, o) in enumerate(zip(choices, second_steps, outcomes)): # loop over trials.

            nc = 1 - c  # Action not chosen at first step.
            ns = 1 - s  # State not reached at second step.

            # Update model free action values. 

            Q_td_f[i+1,nc] = Q_td_f[i, nc] * (1. - D)   # First step forgetting.
            Q_td_s[i+1,ns] = Q_td_s[i, ns] * (1. - D)   # Second step forgetting.

            Q_td_f[i+1,c] = (1. - alpha) * Q_td_f[i,c] + \
                            alpha * (Q_td_s[i,s] + lambd * (o - Q_td_s[i,s])) # First step TD update.
      
            Q_td_s[i+1,s] = (1. - alpha) * Q_td_s[i,s] +  alpha * o           # Second step TD update.

            # Update transition probabilities.

            trans_probs[i+1,nc] = trans_probs[i,nc] - tdec * (trans_probs[i,nc] - 0.5)  # Transition prob. forgetting.
            state_prediction_error = (s == 0) - trans_probs[i,c]
            trans_probs[i+1,c] = trans_probs[i,c] + tlr * state_prediction_error         # Transition prob. update.

            # Update arbitration variable.

            arb[i+1] = arb[i] + alr * (abs(state_prediction_error) - arb[i])  # Move arb towards the unsigned state prediction error at rate alr.

        # Evaluate choice probabilities and likelihood. 

        Q_mb = trans_probs * np.tile(Q_td_s[:,0],[2,1]).T + \
                (1. - trans_probs) * np.tile(Q_td_s[:,1],[2,1]).T # Model based action values. 

        W_arb = np.tile(ru.sigmoid(W - A * arb),[2,1]).T  # Trial by trial model basedness.

        Q_net = W_arb * Q_mb + (1. - W_arb) * Q_td_f # Mixture of model based and model free values.

        if self.use_kernels:
            Q_net[:,0] += kernel_Qs(session, bias, CK, SSK)

        choice_probs = ru.array_softmax(Q_net, iTemp)
        trial_log_likelihood = ru.protected_log(choice_probs[np.arange(n_trials), choices])
        session_log_likelihood = np.sum(trial_log_likelihood)

        if return_trial_data:
            return {'Q_net'       : Q_net[:-1,:],  # Action values
                    'Q_td'        : Q_td_f[:-1,:],
                    'Q_mb'        : Q_mb[:-1,:],
                    'W_arb'       : W_arb[:-1,:],
                    'P_net'       : iTemp *           (Q_net[:-1,1] - (Q_net[:-1,0] + bias)), # Preferences.
                    'P_td'        : iTemp * (1 - W) * (Q_td_f[:-1,1] - Q_td_f[:-1,0]),
                    'P_mb'        : iTemp * W *       (Q_mb[:-1,1] - Q_mb[:-1,0]),
                    'P_k'         : - kernel_Qs(session, 0., CK, SSK)[:-1],
                    'choice_probs': choice_probs}
        else:
            return session_log_likelihood
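
The ru utility functions called above (ru.sigmoid, ru.array_softmax, ru.protected_log) are not shown in this excerpt. The sketch below is a hedged reconstruction of what they plausibly do, inferred only from how they are called in session_likelihood; it is not the module's actual implementation, and the clipping constant in protected_log is an assumption.

import numpy as np

def sigmoid(x):
    # Logistic function; maps the arbitration logit to a weight in (0, 1).
    return 1. / (1. + np.exp(-x))

def array_softmax(Q, iTemp):
    # Softmax over the action axis of an (n_trials, 2) value array,
    # scaled by the inverse temperature iTemp.
    P = np.exp(iTemp * (Q - Q.max(axis=1, keepdims=True)))  # Subtract row max for numerical stability.
    return P / P.sum(axis=1, keepdims=True)

def protected_log(x):
    # Log clipped away from zero so trial log-likelihoods never evaluate to -inf.
    return np.log(np.clip(x, 1e-200, None))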