def performance(spikes, pars, stimulus):
    """Simulate a GQM model neuron from a parameter vector and score it.

    The flat parameter vector ``pars`` is unpacked with ``gqm.splitpars``
    into a linear part ``k``, a quadratic part ``Q`` and an offset ``mu``.
    These build a model neuron via ``gqm.gqm_neuron`` at
    ``st.frame_duration``, which is then driven with ``stimulus`` to get
    a predicted firing rate. That prediction is compared with the
    recorded ``spikes``.

    Parameters
    ----------
    spikes : array-like
        Recorded spike train to compare against. Callers in this script
        appear to pass binned spike counts; confirm for other callers.
    pars : array-like
        Flattened model parameters, in the layout expected by
        ``gqm.splitpars``.
    stimulus : array-like
        Stimulus used to drive the model neuron.

    Returns
    -------
    cross_corr : float
        Zero-lag Pearson correlation coefficient between ``spikes`` and
        the predicted firing rate (entry ``[0, 1]`` of ``np.corrcoef``).
        NumPy returns NaN here if either input has zero variance.
    firing_rate : array-like
        The model's predicted firing rate for ``stimulus``.
    """
    linear, quadratic, offset = gqm.splitpars(pars)
    model_neuron = gqm.gqm_neuron(linear, quadratic, offset, st.frame_duration)
    predicted_rate = model_neuron(stimulus)
    # For two 1-D inputs np.corrcoef gives a 2x2 matrix. Its off-diagonal
    # entry is the correlation between them. This assumes spikes and
    # predicted_rate are 1-D and the same length (TODO confirm).
    correlation_matrix = np.corrcoef(spikes, predicted_rate)
    return correlation_matrix[0, 1], predicted_rate
sp_tr, sp_te, stim_tr, stim_te = train_test_split(spikes, stimulus, test_size=val_split_size, split_pos=val_split_pos) res = gqm.minimize_loglikelihood(np.zeros((stimdim, fl)), np.zeros((stimdim, fl, fl)), 0, stim_tr, st.frame_duration, sp_tr, minimize_disp=True, method='BFGS') elapsed = time.time()-start print(f'Time elapsed: {elapsed/60:6.1f} mins for cell {i}') k_out, Q_out, mu_out = gqm.splitpars(res.x) kall[i, ...] = k_out Qall[i, ...] = Q_out muall[i] = mu_out firing_rate = gqm.gqm_neuron(k_out, Q_out, mu_out, st.frame_duration)(stim_te) cross_corr = np.corrcoef(sp_te, firing_rate)[0, 1] cross_corrs[i] = cross_corr #%% fig, axes = plt.subplots(stimdim, 5, figsize=(15, 5)) plt.rc('font', size=8) for j in range(stimdim): axk = axes[j, 0] if j <= 1: axk.plot(t, sta[j], color='grey', label='STA') axk.plot(t, k_out[j, ...], label='k (GQM)')
np.zeros((stimdim, st.filter_length, st.filter_length)), 0, stimulus, st.frame_duration, spikes, method='BFGS', callback=optim_tracker) elapsed = time.time() - start print(f'Time elapsed: {elapsed/60:6.1f} mins for cell {i}') all_pars_progress.append(pars_progress) all_res.append(res) plt.plot(gqm.splitpars(res.x)[0].T) plt.title('Final linear filters') plt.show() spikes = st.binnedspiketimes(i) if res.nit > 1000: break cc_progress = np.zeros(res.nit) for j, pars in enumerate(pars_progress): cc_progress[j], fr = performance(spikes, pars, stimulus) # plt.plot(ki.T) # plt.title(j) # plt.show()