def get_vfb_spk(task_obj, rew_ix, units):
    '''
    Summary: Extract binned spike counts for every (go, reward) row slice
        of a task entry and stack them into one array.

    Input param: task_obj: task entry object; this function reads its
        .hdf (including .hdf.root.task_msgs), .plx and .plx_filename
    Input param: rew_ix: sequence of reward message times; the last entry
        (+100) bounds the span requested from train._get_tmask_plexon
    Input param: units: iterable of 2-element unit identifiers, passed to
        the extractor and used to pad missing feature columns
    Output param: np.vstack of the per-slice feature arrays, with one
        zero column inserted for each requested unit that the extractor
        did not return
    '''
    task_msgs = task_obj.hdf.root.task_msgs

    # A message's time becomes a slice start when the message three entries
    # later has a time listed in rew_ix.
    go_ix = np.array(
        [
            msg["time"]
            for j, msg in enumerate(task_msgs[:-3])
            if task_msgs[j + 3]["time"] in rew_ix
        ]
    )

    files = dict(hdf=task_obj.hdf, plexon=task_obj.plx_filename)
    _, rows = train._get_tmask_plexon(
        task_obj.plx, (0, rew_ix[-1] + 100), sys_name="task"
    )

    # Extractor:
    extractor_fn = extractor.BinnedSpikeCountsExtractor.extract_from_file
    extractor_kwargs = dict(units=units)

    bin_spk = []
    # Starts and ends are paired positionally; zip stops at the shorter one.
    for start, stop in zip(go_ix, rew_ix):
        neurows = rows[start:stop]
        # NOTE(review): 0.1 is presumably the bin length in seconds -- confirm
        # against BinnedSpikeCountsExtractor.extract_from_file.
        neural_features, un, _ = extractor_fn(
            files, neurows, 0.1, units, extractor_kwargs
        )

        if len(un) != len(units):
            # The extractor returned a different unit list than requested:
            # walk the requested units in order and insert a zero-count
            # column wherever the returned list lacks the expected unit.
            for i, (u0, u1) in enumerate(units):
                # Check bounds first: when the missing unit lies at or past
                # the end of `un`, indexing un[i, 0] would raise IndexError
                # instead of padding.
                missing = i >= len(un) or not np.logical_and(
                    un[i, 0] == u0, un[i, 1] == u1
                )
                if missing:
                    un = np.vstack((un[:i, :], np.array([u0, u1]), un[i:, :]))
                    tm = neural_features.shape[0]
                    neural_features = np.hstack(
                        (
                            neural_features[:, :i],
                            np.zeros((tm, 1)),
                            neural_features[:, i:],
                        )
                    )

        bin_spk.append(neural_features)

    return np.vstack(bin_spk)
def calc_overlap(baseline_te=None, te_w_decoder=None, decoder=None):
    '''
    Summary: Compute the subspace overlap, in both directions, between
        baseline spike counts of the decoder's e1/e2 units and a fixed
        decoder matrix whose e1 rows are 0.1 and whose e2 rows are -0.1.

    Input param: baseline_te: task entry number used to retrieve the
        baseline spike counts
    Input param: te_w_decoder: task entry number whose decoder is used
        when `decoder` is None
    Input param: decoder: decoder object providing .filt.e1_inds,
        .filt.e2_inds and .units; loaded from te_w_decoder when None
    Output param: (OV, OV2): subspace_overlap.get_overlap evaluated as
        (baseline, decoder) and as (decoder, baseline)
    '''
    # Fall back to the decoder stored with the given task entry.
    if decoder is None:
        decoder = dbfn.TaskEntry(te_w_decoder).decoder

    e1_inds = decoder.filt.e1_inds
    e2_inds = decoder.filt.e2_inds

    # Baseline task entry and its plexon file.
    baseline_entry = dbfn.TaskEntry(baseline_te)
    plx = baseline_entry.plx

    # Row timestamps covering the whole baseline entry.
    whole_entry = (0., baseline_entry.length)
    _, neurows = train._get_tmask_plexon(plx, whole_entry, sys_name='task')

    # Keep every `step`-th row so the kept rows are spaced one bin apart
    # (assumes non-overlapping bins; presumably rows arrive at strobe_rate).
    strobe_rate = 60.
    binlen = .1
    step = int(binlen / (1. / strobe_rate))
    interp_rows = neurows[::step]

    # Only the units used by the decoder, split into the e1 and e2 groups.
    all_units = decoder.units
    unit_groups = [all_units[inds] for inds in (e1_inds, e2_inds)]

    # Bin spikes for each group; e1 columns come first, then e2 columns.
    from plexon import psth
    binners = [psth.SpikeBin(group, binlen) for group in unit_groups]
    counts = [
        np.array(list(plx.spikes.bin(interp_rows, binner)))
        for binner in binners
    ]
    SC = np.hstack(counts)

    # Baseline model built from the spike counts.
    baseline_fa = DummyFAClass(SC)

    # Decoder matrix: one row per unit, 10 columns; e1 rows end up at
    # 1/10 and e2 rows at -1/10. The transpose is handed to DummyFAClass.
    n_e1 = len(e1_inds)
    n_e2 = len(e2_inds)
    dec = np.ones((n_e1 + n_e2, 10))
    dec[:n_e1, :] /= 10.
    dec[n_e1:, :] /= -10.
    dec_fa = DummyFAClass(dec.T)

    OV = subspace_overlap.get_overlap(baseline_fa, dec_fa)
    OV2 = subspace_overlap.get_overlap(dec_fa, baseline_fa)
    return OV, OV2