def analyze_OFC_gain(decs_all, dec_name='dmn', qr_rat=None):
    '''Sweep OFC cost/error ratios, train one decoder per ratio on
    OFC-adjusted kinematics, re-run decoding with each and plot trajectories.

    Parameters
    ----------
    decs_all : dict
        Mapping of decoder name -> decoder object. The selected decoder must
        expose ``kin`` (2-D, indexed ``kin[:, k]``) and ``target`` (indexed
        ``target[k, :]``).
    dec_name : str
        Key of the decoder in ``decs_all`` whose ``kin``/``target`` are used.
    qr_rat : sequence of float, optional
        Values passed as ``cost_err_ratio`` to
        ``sim_fa_decoding.EndPtAssisterOFC``. Defaults to
        ``[1., 10e2, 10e4, 10e6, 10e8, 10e10]``.
        NOTE(review): ``10e2 == 1e3`` (not 1e2), so the default sweep is
        1, 1e3, 1e5, 1e7, 1e9, 1e11 -- confirm that is intended.

    NOTE(review): besides its arguments, this function reads the
    module-level names ``ssm``, ``demean``, ``units``, ``hdf`` and
    ``epoch1_ix``; they must be defined before it is called.
    '''
    if qr_rat is None:
        qr_rat = np.array([1., 10e2, 10e4, 10e6, 10e8, 10e10])
    decoder = decs_all[dec_name]
    kin = decoder.kin
    dec_adj = dict()
    for qr in qr_rat:
        # Allocate a fresh buffer per ratio. The training call may keep a
        # reference to its kinematics input; reusing one buffer would let
        # later ratios overwrite the data backing earlier decoders. The
        # contents passed to training are unchanged: rows 0 and 2 are copied
        # from ``kin``, rows 3 and 5 are filled below (column 0 stays zero),
        # and every other row stays zero.
        kin_adj = np.zeros_like(kin)
        kin_adj[[0, 2], :] = kin[[0, 2], :]

        input_device = sim_fa_decoding.EndPtAssisterOFC(cost_err_ratio=qr)
        for k in range(kin.shape[1] - 1):
            cs = kin[:, k][:, np.newaxis]
            # Target state: target position, three zeros, then a trailing 1.
            ts = np.vstack((decoder.target[k, :][:, np.newaxis],
                            np.zeros((3, 1)), np.ones((1, 1))))

            # The current state is augmented with a trailing 1 before the
            # controller computes the next state.
            ctrl = input_device.calc_next_state(
                np.vstack((cs, np.ones((1, 1)))), ts)
            # Rows 3 and 5 of the controller output become the next column.
            kin_adj[[3, 5], k + 1] = np.squeeze(np.array(ctrl[[3, 5], 0]))

        dec_adj[qr] = train.train_KFDecoder_abstract(ssm, kin_adj, demean, units, 0.1)

    # Re-run decoding on the demeaned input with each adjusted decoder.
    trbt_dec = {}
    for qr_ in dec_adj:
        dec = trbt.RerunDecoding(hdf, dec_adj[qr_], task='bmi_resetting')
        dec.add_input(sdf.rebin_spks(demean), 'demean')
        trbt_dec[qr_] = dec

    for q in trbt_dec:
        plot_sim_traj.traj_sem_plot(trbt_dec[q], 'demean', it_start=epoch1_ix,
                                    extra_title_text=str(q))
def plot_stuff(dec):
    '''Plot the decoded trajectory of every input type for every decoder.

    Parameters
    ----------
    dec : dict
        Maps a decoder label (str; used in the plot title as
        ``label + ' decoder'``) to an object exposing a ``decoded_pos``
        mapping keyed by input type (presumably ``trbt.RerunDecoding``
        instances -- confirm against callers).

    Uses the module-level ``epoch1_ix`` as the starting iteration.
    '''
    for r in dec:
        R = dec[r]
        for input_type in R.decoded_pos.keys():
            plot_sim_traj.traj_sem_plot(R, input_type, it_start=epoch1_ix,
                                        extra_title_text=r + ' decoder')
def plot_metrics(dec, inputs, pairs_to_plot=None):
    ''' Plot magnitude of Kalman Gain in Velocity with variance of neural input for each neuron, for
    all types of input.

    For each (input key, decoder key) pair:

    - take the variance of ``inputs[input_key]`` along axis 1;
    - for each index ``j`` of that variance vector, take the norm of rows
      3 and 5 of column ``j`` of the decoder's steady-state gain ``K``,
      scaled by ``var[j]``;
    - scatter the norms against the variances and overlay a least-squares
      line (``scipy.stats.linregress``).

    The decoded trajectory for each pair is also drawn with
    ``plot_sim_traj.traj_sem_plot``.

    Parameters
    ----------
    dec : dict
        Decoder key -> object exposing ``dec.filt.get_sskf()`` returning
        ``(F, K)``.
    inputs : dict
        Input key -> neural input array; variance is taken along axis 1.
    pairs_to_plot : list of (str, str), optional
        ``(input key, decoder key)`` pairs. Defaults to pairing each input
        with its own decoder: ('demn', 'R_demn'),
        ('main_sc_shar', 'R_sc_shar') and ('main_shar', 'R_shar'). An
        alternative comparison pairs every input with 'R_demn'.

    Uses the module-level ``epoch1_ix`` as the trajectory start index.
    '''
    if pairs_to_plot is None:
        pairs_to_plot = [('demn', 'R_demn'),
                         ('main_sc_shar', 'R_sc_shar'),
                         ('main_shar', 'R_shar')]

    cmap = ['r', 'g', 'b', 'm', 'k']

    _, ax = plt.subplots()
    for ii, (inp_name, dec_name) in enumerate(pairs_to_plot):
        # Cycle colours so more than len(cmap) pairs does not raise.
        color = cmap[ii % len(cmap)]
        var = np.squeeze(np.array(np.var(inputs[inp_name], axis=1)))

        R = dec[dec_name]
        _, K = R.dec.filt.get_sskf()
        # Norm of the variance-scaled gain (rows 3 and 5) for each column j.
        vel_ = [np.linalg.norm(K[[3, 5], j] * var_j) for j, var_j in enumerate(var)]
        ax.plot(var, vel_, '.', color=color,
                label='Input: ' + inp_name + ' Dec: ' + dec_name)
        slp, int_, cc, pv, std = ss.linregress(var, vel_)
        vel_star = int_ + slp * var
        ax.plot(var, vel_star, '-', color=color, label='slope: ' + str(slp))

        # Plot trajectory map:
        plot_sim_traj.traj_sem_plot(R, inp_name, it_start=epoch1_ix,
                                    extra_title_text=dec_name + ' decoder')

    ax.set_xlabel('Variance of Neuron')
    ax.set_ylabel('Norm of Vel. from K_g*Var')
    ax.set_ylim([0., 1.2])
    # Bind the legend to this figure's axes explicitly: traj_sem_plot is called
    # without ``ax=`` above and may change pyplot's current axes, in which case
    # plt.legend() would decorate the wrong figure.
    ax.legend()
# Esempio n. 4 (Italian: "Example no. 4")
# 0
def test_xz_KF(kf_decoder, dim_red_dict, hdf):
    '''Re-run ``kf_decoder`` on three versions of the recorded spike counts
    and plot the decoded trajectories in three side-by-side panels.

    Panels:
      0. raw spike counts (input name 'all');
      1. the FA shared component scaled by ``fa_shar_var_sc`` plus the FA
         mean (input name 'shar_sc');
      2. the main-shared projection scaled by ``fa_main_shared_sc`` plus the
         FA mean (input name 'main_sc_shar').

    Parameters
    ----------
    kf_decoder : decoder passed to ``trbt.RerunDecoding``.
    dim_red_dict : dict
        Reads keys 'fa_mu', 'fa_sharL', 'fa_shar_var_sc', 'fa_main_shared'
        and 'fa_main_shared_sc'.
        NOTE(review): 'fa_sharL' and 'fa_main_shared' are combined with
        ``*``, which is a matrix product only if they are ``np.matrix``
        objects -- confirm.
    hdf : object whose ``root.task`` table has a 'spike_counts' column
        (indexed as 3-D, ``[:, :, 0]``).

    Returns
    -------
    ax : array of the three subplot axes.
    Recoder : the ``trbt.RerunDecoding`` object holding all three inputs.
    '''
    f, ax = plt.subplots(nrows=1, ncols=3)

    all_spike_counts = hdf.root.task[:]['spike_counts']
    it_cutoff = len(hdf.root.task)

    Recoder = trbt.RerunDecoding(hdf, kf_decoder, task='bmi_multi')
    Recoder.add_input(all_spike_counts, 'all')
    plot_sim_traj.traj_sem_plot(Recoder, 'all', it_start=0, it_cutoff=it_cutoff,
                                ax=ax[0], rm_assist=False)

    # Scaled shared component, re-centred on the FA mean.
    T = all_spike_counts.shape[0]
    tile_mn = np.tile(dim_red_dict['fa_mu'][np.newaxis, :, 0], [T, 1])
    dmn = all_spike_counts[:, :, 0] - tile_mn
    shar = dim_red_dict['fa_sharL'] * (dmn.T)

    sc_shar = np.multiply(shar, np.tile(dim_red_dict['fa_shar_var_sc'], [1, T])).T
    sc_shar = np.array(sc_shar[:, :, np.newaxis]) + np.array(tile_mn[:, :, np.newaxis])
    Recoder.add_input(sc_shar, 'shar_sc')
    plot_sim_traj.traj_sem_plot(Recoder, 'shar_sc', it_start=0, it_cutoff=it_cutoff,
                                ax=ax[1], rm_assist=False)

    # Scaled main-shared component, re-centred on the FA mean.
    main_shar = (dmn * dim_red_dict['fa_main_shared'])
    sc_fact = np.tile(dim_red_dict['fa_main_shared_sc'].T, [T, 1])
    main_sc_shar = np.multiply(main_shar, sc_fact) + tile_mn
    Recoder.add_input(main_sc_shar[:, :, np.newaxis], 'main_sc_shar')
    plot_sim_traj.traj_sem_plot(Recoder, 'main_sc_shar', it_start=0, it_cutoff=it_cutoff,
                                ax=ax[2], rm_assist=False)

    return ax, Recoder
# Route each observed input type (truncated to its first ``ix`` entries along
# axis 0) to the matching re-simulation object; 'split_shar_z' is skipped.
# Strings are compared with ``==``: ``is`` tests object identity, which is
# not guaranteed for equal strings (e.g. keys loaded from disk or built at
# runtime).
for i_ in input_type_dict_obs.keys():
    if i_ == 'split':
        ReSim_obs_split.add_input(input_type_dict_obs[i_][:ix, :, :], input_type=i_)
    elif i_ == 'split_shar_z':
        print('skip split - shar_z')
    else:
        ReSim_obs.add_input(input_type_dict_obs[i_][:ix, :, :], input_type=i_)

import plot_sim_traj as pst
# Trajectory plots for the ReSim / ReSim_split* objects, each with
# it_cutoff=60000.
pst.traj_sem_plot(ReSim, 'all', it_cutoff=60000)
pst.traj_sem_plot(ReSim_split_shar_z, 'split_shar_z', it_cutoff=60000)
pst.traj_sem_plot(ReSim_split, 'split', it_cutoff=60000)
pst.traj_sem_plot(ReSim, 'main_shared', it_cutoff=60000)
pst.traj_sem_plot(ReSim, 'main_private', it_cutoff=60000)

# Same plots for the *_obs re-simulations (the objects filled by the routing
# loop above), each with it_cutoff=29000.
pst.traj_sem_plot(ReSim_obs, 'all', it_cutoff=29000)
pst.traj_sem_plot(ReSim_obs_split, 'split', it_cutoff=29000)
pst.traj_sem_plot(ReSim_obs, 'main_shared', it_cutoff=29000)
pst.traj_sem_plot(ReSim_obs, 'main_private', it_cutoff=29000)

#Decompose:
def decompose(bin_spk_i, fa_dict):
	T = bin_spk_i.shape[0]
	mn = np.tile(fa_dict['fa_mu'], [1, T])
	dmn = bin_spk_i.T - mn