def session_likelihood(self, session, params_T, return_trial_data=False):
    '''Evaluate the log likelihood of one session's choices under a hybrid
    model-based / model-free agent with trial-by-trial arbitration.

    Parameters:
    session           - Session object; session.CTSO holds choices, second steps
                        and outcomes (unpacked via ut.CTSO_unpack).
    params_T          - Parameter vector in "true" space. First 9 entries:
                        alpha (learning rate), iTemp (inverse temperature),
                        lambd (eligibility trace), W (model-basedness bias),
                        tlr (transition learning rate), D (forgetting rate),
                        tdec (transition prob. decay), A (arbitration weight),
                        alr (arbitration learning rate). If self.use_kernels,
                        the last 3 entries are bias, CK (choice kernel) and
                        SSK (second-step kernel) strengths.
    return_trial_data - If True, return a dict of per-trial latent variables
                        (action values, preferences etc.) instead of the
                        scalar session log likelihood.

    Returns:
    Scalar session log likelihood, or a dict of trial-by-trial variables
    when return_trial_data is True.
    '''
    # Unpack trial events.
    choices, second_steps, outcomes = ut.CTSO_unpack(session.CTSO, 'CSO')
    # Unpack parameters.
    alpha, iTemp, lambd, W, tlr, D, tdec, A, alr = params_T[:9]  # alr: learning rate for arbitration.
    bias = CK = SSK = 0.  # Kernel parameters default to 0 when kernels are not used,
                          # so the trial-data branch below never hits unbound names.
    if self.use_kernels:
        bias, CK, SSK = params_T[-3:]
    # Variables.  All arrays have n_trials + 1 rows: row i holds the values
    # *before* trial i, so row n_trials is the post-session state.
    n_trials = len(choices)
    Q_td_f = np.zeros([n_trials + 1, 2])       # Model free first step action values (low, high).
    Q_td_s = np.zeros([n_trials + 1, 2])       # Model free second step action values (right, left).
    arb = np.zeros(n_trials + 1)               # Arbitration parameter, positive means more model based.
    trans_probs = np.zeros([n_trials + 1, 2])  # Transition probabilities for low and high pokes.
    trans_probs[0, :] = 0.5                    # Initialize first trial transition probabilities.

    for i, (c, s, o) in enumerate(zip(choices, second_steps, outcomes)):  # Loop over trials.
        nc = 1 - c  # Action not chosen at first step.
        ns = 1 - s  # State not reached at second step.

        # Update model free action values.
        Q_td_f[i+1, nc] = Q_td_f[i, nc] * (1. - D)  # First step forgetting.
        Q_td_s[i+1, ns] = Q_td_s[i, ns] * (1. - D)  # Second step forgetting.
        Q_td_f[i+1, c] = (1. - alpha) * Q_td_f[i, c] + \
            alpha * (Q_td_s[i, s] + lambd * (o - Q_td_s[i, s]))     # First step TD update.
        Q_td_s[i+1, s] = (1. - alpha) * Q_td_s[i, s] + alpha * o    # Second step TD update.

        # Update transition probabilities.
        trans_probs[i+1, nc] = trans_probs[i, nc] - tdec * (trans_probs[i, nc] - 0.5)  # Transition prob. forgetting.
        state_prediction_error = (s == 0) - trans_probs[i, c]
        trans_probs[i+1, c] = trans_probs[i, c] + tlr * state_prediction_error  # Transition prob. update.

        # Update arbitration: tracks a running average of the unsigned state
        # prediction error; large errors push the agent towards model free control.
        arb[i+1] = arb[i] + alr * (abs(state_prediction_error) - arb[i])

    # Evaluate choice probabilities and likelihood.
    Q_mb = trans_probs * np.tile(Q_td_s[:, 0], [2, 1]).T + \
        (1. - trans_probs) * np.tile(Q_td_s[:, 1], [2, 1]).T  # Model based action values.
    W_arb = np.tile(ru.sigmoid(W - A * arb), [2, 1]).T         # Trial by trial model basedness.
    Q_net = W_arb * Q_mb + (1. - W_arb) * Q_td_f               # Mixture of model based and model free values.
    if self.use_kernels:
        Q_net[:, 0] += kernel_Qs(session, bias, CK, SSK)
    choice_probs = ru.array_softmax(Q_net, iTemp)
    trial_log_likelihood = ru.protected_log(choice_probs[np.arange(n_trials), choices])
    session_log_likelihood = np.sum(trial_log_likelihood)

    if return_trial_data:
        # Per-trial latents; drop the final row as it has no corresponding trial.
        return {'Q_net': Q_net[:-1, :],   # Action values.
                'Q_td': Q_td_f[:-1, :],
                'Q_mb': Q_mb[:-1, :],
                'W_arb': W_arb[:-1, :],
                'P_net': iTemp * (Q_net[:-1, 1] - (Q_net[:-1, 0] + bias)),          # Preferences.
                'P_td': iTemp * (1 - W) * (Q_td_f[:-1, 1] - Q_td_f[:-1, 0]),
                'P_mb': iTemp * W * (Q_mb[:-1, 1] - Q_mb[:-1, 0]),
                'P_k': - kernel_Qs(session, 0., CK, SSK)[:-1],
                'choice_probs': choice_probs}
    else:
        return session_log_likelihood