def decompose(bin_spk_i, fa_dict):
    """Build the re-binned input streams derived from a factor-analysis fit.

    The counts are transposed and the tiled ``fa_dict['fa_mu']`` is
    subtracted; the result is projected through ``fa_dict['fa_main_shared']``
    and ``fa_dict['fa_sharL']`` and each projection's residual is kept as the
    matching "private" part. Every stream goes through ``sdf.rebin_spks``.

    Parameters
    ----------
    bin_spk_i : array-like
        Binned spike counts; ``shape[0]`` is used as the number of bins.
        NOTE(review): presumably (bins, units) -- confirm against the caller.
    fa_dict : dict
        Must provide 'fa_mu', 'fa_main_shared', 'fa_sharL', 'FA_model'
        (an object with ``transform``) and 'fa_main_shar_n_dim'.

    Returns
    -------
    dict
        Keys 'main_shared', 'main_private', 'all', 'split_shar_z', 'split'.
    """
    n_bins = bin_spk_i.shape[0]
    mean_tiled = np.tile(fa_dict['fa_mu'], [1, n_bins])
    centered = bin_spk_i.T - mean_tiled

    # NOTE(review): `*` is a matrix product only if the stored entries are
    # np.matrix objects; with plain ndarrays it is elementwise -- confirm.
    main_shared_part = fa_dict['fa_main_shared'] * centered
    main_private_part = centered - main_shared_part

    shared_part = fa_dict['fa_sharL'] * centered
    private_part = centered - shared_part

    # The plain 'shared' / 'private' streams are deliberately not emitted.
    input_type_dict = {
        'main_shared': sdf.rebin_spks(main_shared_part + mean_tiled),
        'main_private': sdf.rebin_spks(main_private_part + mean_tiled),
        'all': sdf.rebin_spks(bin_spk_i.T),
    }

    # Factor scores, transposed so that factors run along the first axis.
    factors = fa_dict['FA_model'].transform(centered.T).T
    main_factors = factors[:fa_dict['fa_main_shar_n_dim'], :]

    input_type_dict['split_shar_z'] = sdf.rebin_spks(
        np.vstack((shared_part, private_part)))
    input_type_dict['split'] = sdf.rebin_spks(
        np.vstack((main_factors, private_part)))

    return input_type_dict
def analyze_OFC_gain(decs_all, dec_name='dmn'):
    """Retrain and replay a decoder for a sweep of assist cost/error ratios.

    For each ratio, an ``EndPtAssisterOFC`` is stepped along the decoder's
    kinematics and rows 3 and 5 of its output replace the matching rows of
    an adjusted kinematics array (rows 0 and 2 are copied from the original).
    A decoder is trained on the adjusted kinematics, re-run with the 'demean'
    input, and its trajectories are plotted.

    Parameters
    ----------
    decs_all : mapping
        Decoders by name; ``decs_all[dec_name]`` must expose ``kin`` and
        ``target``.
    dec_name : str
        Key of the decoder to analyze.

    Returns
    -------
    None

    Notes
    -----
    NOTE(review): ``ssm``, ``demean``, ``units``, ``hdf`` and ``epoch1_ix``
    are not parameters; they must be resolvable from the enclosing scope.
    The same ``kin_adj`` array object is reused (and overwritten) for every
    ratio.
    """
    qr_rat = np.array([1., 10e2, 10e4, 10e6, 10e8, 10e10])
    decoder = decs_all[dec_name]
    kin = decoder.kin

    kin_adj = np.zeros_like(kin)
    kin_adj[[0, 2], :] = kin[[0, 2], :]
    n_steps = kin.shape[1] - 1

    dec_adj = {}
    for qr in qr_rat:
        assister = sim_fa_decoding.EndPtAssisterOFC(cost_err_ratio=qr)
        for k in range(n_steps):
            cs = kin[:, k][:, np.newaxis]
            # Target column padded with three zeros and a trailing one.
            ts = np.vstack((
                decoder.target[k, :][:, np.newaxis],
                np.zeros((3, 1)),
                np.ones((1, 1)),
            ))
            current = np.vstack((cs, np.ones((1, 1))))
            ctrl = assister.calc_next_state(current, ts)
            kin_adj[[3, 5], k + 1] = np.squeeze(np.array(ctrl[[3, 5], 0]))

        dec_adj[qr] = train.train_KFDecoder_abstract(ssm, kin_adj, demean, units, 0.1)

    trbt_dec = {}
    for qr_, adjusted in dec_adj.items():
        rerun = trbt.RerunDecoding(hdf, adjusted, task='bmi_resetting')
        rerun.add_input(sdf.rebin_spks(demean), 'demean')
        trbt_dec[qr_] = rerun

    for q, rerun in trbt_dec.items():
        plot_sim_traj.traj_sem_plot(rerun, 'demean', it_start=epoch1_ix, extra_title_text=str(q))
# Example #3
0
def decompose_inputs(fa_dict, bin_spk_i, hdf, dec, task='bmi_resetting', process_to_ix=None, use_main = True):
    """Set up a ``RerunDecoding`` with 'all', shared and private input streams.

    The counts are transposed and the tiled ``fa_dict['fa_mu']`` is
    subtracted; the result is projected through either
    ``fa_dict['fa_main_shared']`` (``use_main`` truthy) or
    ``fa_dict['fa_sharL']``; the residual is the private part. Both parts
    have the mean added back and are passed through ``sdf.rebin_spks``.

    Parameters
    ----------
    fa_dict : dict
        Must provide 'fa_mu' and the projection selected by ``use_main``.
    bin_spk_i : array-like
        Binned spike counts; ``shape[0]`` is used as the number of bins.
    hdf : object
        Task file exposing ``root.task`` and ``root.task_msgs``.
    dec : object
        Decoder handed to ``trbt.RerunDecoding``.
    task : str
        Task name forwarded to ``trbt.RerunDecoding``.
    process_to_ix : int or None
        Number of leading entries of each stream to register; ``None`` uses
        the full length of the re-binned shared stream.
    use_main : bool
        Select the 'fa_main_shared' projection instead of 'fa_sharL'.

    Returns
    -------
    tuple
        ``(ReSim, gr_ix, proc_ix)`` where ``gr_ix`` pairs, for every
        'reward' message, the time of the message three entries earlier
        with the reward time.

    Notes
    -----
    The streams are registered under 'main_shar' / 'main_priv' regardless of
    ``use_main``. Only the first 50 rewarded trials are used for
    ``ReSim.target_ix`` / ``ReSim.targ_pos_trunc``.
    """
    T = bin_spk_i.shape[0]
    mn = np.tile(fa_dict['fa_mu'], [1, T])
    dmn = bin_spk_i.T - mn

    # Forward the caller's task; previously the literal 'bmi_resetting' was
    # always used and the `task` argument was silently ignored.
    ReSim = trbt.RerunDecoding(hdf, dec, task=task)

    proj_key = 'fa_main_shared' if use_main else 'fa_sharL'
    shar = fa_dict[proj_key] * dmn
    priv = dmn - shar
    ReSim.main = bool(use_main)

    main_shar_spks = sdf.rebin_spks(shar + mn)
    main_priv_spks = sdf.rebin_spks(priv + mn)

    proc_ix = main_shar_spks.shape[0] if process_to_ix is None else process_to_ix

    ReSim.add_input(ReSim.spike_counts[:proc_ix, :, :], 'all')
    ReSim.add_input(main_shar_spks[:proc_ix, :, :], 'main_shar')
    ReSim.add_input(main_priv_spks[:proc_ix, :, :], 'main_priv')

    # (time of the message three entries before each reward, reward time).
    # NOTE(review): a reward among the first three messages would wrap to
    # the end of the table through the negative index -- confirm it cannot
    # happen.
    gr_ix = np.array([(hdf.root.task_msgs[j-3]['time'], i['time']) for j, i
        in enumerate(hdf.root.task_msgs) if i['msg'] == 'reward'])

    targ_pos = hdf.root.task[gr_ix[:50,1]]['target']
    targ_ix = pa.get_target_ix(targ_pos[:,[0, 2]])
    ReSim.target_ix = targ_ix
    ReSim.targ_pos_trunc = targ_pos
    return ReSim, gr_ix, proc_ix
# Example #4
0
def psth(Recoder,bins_after_go=60*4, input_type='all'):
    """Collect post-go spike windows grouped by target index.

    Go indices are taken from the message three entries before each
    'reward' message. For every go index whose window of ``bins_after_go``
    entries ends before the end of the decoded positions, the window of the
    re-binned spikes (last-axis slice 0) is stacked along a third axis with
    the other windows of the same target.

    Parameters
    ----------
    Recoder : object
        Must expose ``hdf``, ``decoded_pos``, ``dec_spk_cnt_bin`` and
        ``target``.
    bins_after_go : int
        Window length after each go index.
    input_type : str
        Key into ``decoded_pos`` and ``dec_spk_cnt_bin``.

    Returns
    -------
    dict
        ``int`` target index -> array of stacked windows (a single 2-D
        window if the target occurred once).
    """
    hdf = Recoder.hdf
    msgs = hdf.root.task_msgs

    # Entry three messages before every reward supplies the go index.
    go_ix = np.array([msgs[it-3][1] for it, t in enumerate(msgs[:]) if t[0] == 'reward'])
    n_decoded = len(Recoder.decoded_pos[input_type])
    go_ix = np.hstack((go_ix, n_decoded))
    spks = sdf.rebin_spks(Recoder.dec_spk_cnt_bin[input_type])

    targ_ix = pa.get_target_ix(Recoder.target[:, [0,2]])
    windows_by_target = {}

    for go in go_ix[:-1]:
        ti = targ_ix[go]
        if go + bins_after_go >= n_decoded:
            # Window would run past the decoded data; skip this trial.
            continue

        window = np.squeeze(spks[go:go+bins_after_go, : , 0])
        key = int(ti)
        if key in windows_by_target:
            windows_by_target[key] = np.dstack((windows_by_target[key], window))
        else:
            windows_by_target[key] = window
    return windows_by_target
def run_sim(decoder_full, decoder_shar, decoder_sc_shar, bin_spk_cnts2, epoch1_ix, epoch2_ix, update_bmi_ix, FA_dict):
    """Create re-run decoders and register the FA-derived input streams.

    Three ``trbt.RerunDecoding`` objects are built ('R_demn', 'R_shar',
    'R_sc_shar'). Five streams are derived from the de-meaned counts
    ('demn', 'main_shar', 'main_sc_shar', 'full_sc',
    'main_sc_shar_pls_priv') and each is re-binned and added to every
    decoder except 'R_demn'.

    Parameters
    ----------
    decoder_full, decoder_shar, decoder_sc_shar : object
        Decoders for the 'R_demn', 'R_shar' and 'R_sc_shar' re-runs.
    bin_spk_cnts2 : array-like
        Binned spike counts; ``shape[0]`` is used as the number of bins.
    epoch1_ix, epoch2_ix, update_bmi_ix
        Accepted for interface compatibility; not used by this function.
    FA_dict : dict
        Must provide 'fa_mu', 'fa_main_shared' and 'fa_main_shared_sc'.

    Returns
    -------
    dict
        The three re-run decoders keyed by name.

    Notes
    -----
    NOTE(review): ``hdf`` is not a parameter and must be resolvable from the
    enclosing scope. 'R_demn' receives no inputs here, as before -- confirm
    that is intended.
    """
    dec = dict()
    # Use the decoder that was passed in: the previous body referenced
    # `decoder_demn`, which is not bound in this function, and left the
    # `decoder_full` argument unused.
    dec['R_demn'] = trbt.RerunDecoding(hdf, decoder_full, task='bmi_resetting')
    dec['R_shar'] = trbt.RerunDecoding(hdf, decoder_shar, task='bmi_resetting')
    dec['R_sc_shar'] = trbt.RerunDecoding(hdf, decoder_sc_shar, task='bmi_resetting')

    T = bin_spk_cnts2.shape[0]
    scale = np.tile(FA_dict['fa_main_shared_sc'], [1, T])
    demean = bin_spk_cnts2.T - np.tile(FA_dict['fa_mu'], [1, T])

    inputs = dict()
    inputs['demn'] = demean
    inputs['main_shar'] = FA_dict['fa_main_shared'] * demean
    main_priv = demean - inputs['main_shar']
    inputs['main_sc_shar'] = np.multiply(inputs['main_shar'], scale)
    inputs['full_sc'] = np.multiply(demean, scale)
    inputs['main_sc_shar_pls_priv'] = inputs['main_sc_shar'] + main_priv

    for name, stream in inputs.items():
        for tp, rerun in dec.items():
            if tp == 'R_demn':
                continue
            # Re-binned separately per decoder so no array is shared.
            rerun.add_input(sdf.rebin_spks(stream), name)
    return dec
# Example #6
0
	shar[i, :, :] = dim_red_te['fa_sharL']*dmn[i, :, :]

# Per-slice projection of `dmn` through the 'fa_main_shared' entry, written
# into an array shaped like `spike_counts` (the statement just before this
# does the same with 'fa_sharL').
# NOTE(review): `spike_counts`, `dmn`, `T`, `shar` and `dim_red_te` are
# defined before this excerpt. `*` is a matrix product only if the stored
# entry is an np.matrix -- confirm.
main_shar = np.zeros_like(spike_counts)
for i in range(T):
	main_shar[i, :, :] = dim_red_te['fa_main_shared']*dmn[i, :, :]

# Scaling terms read from dim_red_te; each gets a new leading axis and is
# tiled T times along it so it lines up with the per-slice arrays above.
shar_var = np.array(dim_red_te['fa_shar_var_sc'])
main_shar_var = np.array(dim_red_te['fa_main_shared_sc'])
T_shar_var = np.tile( shar_var[np.newaxis, :, :], [T, 1, 1])
T_main_shar_var = np.tile( main_shar_var[np.newaxis, :, :], [T, 1, 1])

# Elementwise scaling, then add T_mn.
# NOTE(review): T_mn is defined outside this excerpt; presumably the tiled
# mean -- confirm.
shar_sc_input = np.multiply(shar, T_shar_var) + T_mn
main_shar_sc_input = np.multiply(main_shar, T_main_shar_var) + T_mn

# Re-bin each stream: take slice 0 of the last axis, transpose, and hand it
# to sdf.rebin_spks.
# NOTE(review): mid-script import; consider moving it to the top of the file.
import sim_decoding_FA as sdf
all_ = sdf.rebin_spks(spike_counts[:,:,0].T)
shar_ = sdf.rebin_spks(shar_sc_input[:,:, 0].T)
main_shar_ = sdf.rebin_spks(main_shar_sc_input[:,:, 0].T)


# Register the first 20000 entries of each stream on Recoder under its own
# label and plot the trajectory for that label right after adding it.
Recoder.add_input(all_[:20000,:,:], 'all_proc')
plot_sim_traj.plot_traj(Recoder, input_type='all_proc')

Recoder.add_input(shar_[:20000,:,:], 'sh_sc')
plot_sim_traj.plot_traj(Recoder, input_type='sh_sc')

Recoder.add_input(main_shar_[:20000,:,:], 'main_sh_sc')
plot_sim_traj.plot_traj(Recoder, input_type='main_sh_sc')
def main(hdf=None, dec=None, mesh_plot=False):
    """Plot z-scored input activity against cursor velocity, one figure per unit.

    Cursor-update samples that fall inside rewarded trials are selected,
    every input stream is z-scored per unit over its first axis, and for
    each unit a 2x3 grid of ``plot_cont`` panels (one per stream) is saved
    as ``un<unit>.png`` under
    ``$FA_GROM_DATA/../online_analysis/unit_tuning_4549/``.

    Parameters
    ----------
    hdf : object or None
        Task file exposing ``root.task`` and ``root.task_msgs``. When
        ``None``, task entry 4038 is loaded through ``dbfn`` and the input
        streams stored in the task table are used.
    dec : object or None
        Decoder passed to ``generate_FA_matrices``; only used when ``hdf``
        is given.
    mesh_plot : bool
        Forwarded to ``plot_cont``; when true the colour bar is disabled on
        every panel.

    Returns
    -------
    None

    Notes
    -----
    Units with zero standard deviation divide by zero during z-scoring and
    therefore yield NaN/inf values, as before.
    """
    if hdf is None:
        # NOTE(review): this path needs `dbfn` in scope; its import is
        # commented out here.
        #from db import dbfunctions as dbfn
        fa_te = dbfn.TaskEntry(4038)
        fa_hdf = fa_te.hdf
    else:
        fa_hdf = hdf

    # Read the task table once instead of re-reading it for every field.
    task_tbl = fa_hdf.root.task[:]

    cursor_pos = task_tbl['cursor'][:, [0, 2]]
    d_cursor = np.diff(cursor_pos, axis=0)
    update_ix = np.nonzero(np.diff(cursor_pos[:,0]))[0]

    #Trial epochs: (time of the message three entries before reward, reward time)
    epoch_ix = np.array([ [fa_hdf.root.task_msgs[j-3]['time'], i['time']]
        for j, i in enumerate(fa_hdf.root.task_msgs[:]) if i['msg']=='reward'])

    # Keep only cursor updates that fall inside a rewarded trial.
    ixx = []
    for g, r in epoch_ix:
        in_trial = np.logical_and(update_ix <= r, update_ix > g)
        ixx.append(update_ix[in_trial])
    IX = np.hstack(ixx)

    #Metrics
    cursor_vel = d_cursor[IX,:]
    input_dict_zsc = {}

    if hdf is None: #Old version: streams were saved in the task table
        input_types = ['all','shared','private','shared_sc','private_sc']
        input_type_dict = {}
        for i in input_types:
            input_type_dict[i] = task_tbl[i+'_input']
        fig_map = dict(all=[0, 2], shared=[0, 0], private=[0, 1], shared_sc=[1, 0], private_sc=[1, 1])

    else:
        fa_dict = factor_analysis_tasks.FactorBMIBase.generate_FA_matrices(None, hdf=fa_hdf, dec=dec)
        input_type_dict = {}

        #Calculate spikes to transform: a bin ends wherever the selected
        #internal decoder state entry changes.
        drives_neurons_ix0 = 3
        internal_state = task_tbl['internal_decoder_state']
        update_bmi_ix = np.nonzero(np.diff(np.squeeze(internal_state[:, drives_neurons_ix0, 0])))[0]+1
        spike_i = task_tbl['spike_counts'][:,:,0]
        bin_spk_i = np.zeros((len(update_bmi_ix), spike_i.shape[1]))

        #binned spikes: sum the six samples ending at each update index.
        for ib, i_ix in enumerate(update_bmi_ix):
            # Clamp the window start: a negative start would be interpreted
            # relative to the end of the array and select the wrong (usually
            # empty) range for updates within the first five samples.
            start = max(i_ix - 5, 0)
            bin_spk_i[ib, :]= np.sum(spike_i[start:i_ix+1, :], axis=0)

        #Decompose:
        T = bin_spk_i.shape[0]
        mn = np.tile(fa_dict['fa_mu'], [1, T])
        dmn = bin_spk_i.T - mn

        main_shar = (fa_dict['fa_main_shared'] * dmn)
        main_priv = (dmn - main_shar)

        shar = (fa_dict['fa_sharL'] * dmn)
        priv = (dmn - shar)

        input_type_dict['shared'] = sdf.rebin_spks(shar + mn)
        input_type_dict['private'] = sdf.rebin_spks(priv + mn)
        input_type_dict['main_shared'] = sdf.rebin_spks(main_shar + mn)
        input_type_dict['main_private'] = sdf.rebin_spks(main_priv + mn)
        input_type_dict['all'] = sdf.rebin_spks(bin_spk_i.T)
        fig_map = dict(all=[0, 2], shared=[0, 0], private=[0, 1], main_shared=[1, 0], main_private=[1, 1])

    # Z-score every stream per unit, sampled one entry after each selected
    # cursor update.
    T = len(IX)
    for input_, inp in input_type_dict.items():
        inp[np.isnan(inp)] = 0

        mFR = np.mean(inp, axis=0)
        sdFR= np.std(inp, axis=0)
        zsc = (inp[IX+1] - np.tile(mFR, [T, 1, 1])) / np.tile(sdFR, [T, 1, 1])
        input_dict_zsc[input_] = zsc
        n_units = zsc.shape[1]

    # For each unit make a contour plot

    for unit in range(n_units):
        f, ax = plt.subplots(nrows = 2, ncols = 3)

        for input_ in input_type_dict:
            row, col = fig_map[input_]
            axi = ax[row, col]

            X = input_dict_zsc[input_][:, unit, :]

            # Only the bottom-middle panel gets a colour bar.
            bar = (row == 1) and (col == 1)

            if mesh_plot:
                plot_cont(axi, X, cursor_vel, bar=False, mesh_plot=True)
            else:
                plot_cont(axi, X, cursor_vel, bar=bar)
            axi.set_title(input_)

        fname = os.path.expandvars('$FA_GROM_DATA/../online_analysis/unit_tuning_4549/un'+str(unit)+'.png')
        f.savefig(fname)
        # pyplot keeps every figure alive until it is closed; release it so
        # memory does not grow with the number of units.
        plt.close(f)