def _merge_adjacent_epochs(epochs):
    """
    Merge consecutive epochs that abut (stop of one == start of the next).

    Parameters
    ----------
    epochs : array-like, shape (N, 2)
        Start/stop pairs, assumed sorted and non-overlapping (e.g. the output
        of ``ep.remove_overlap``). Adjacency is tested with exact equality.

    Returns
    -------
    numpy.ndarray
        A new (M, 2) array with M <= N; the input array is not modified.
    """
    epochs = np.asarray(epochs)
    if epochs.shape[0] == 0:
        return epochs.copy()
    merged = [epochs[0].copy()]
    for row in epochs[1:]:
        if row[0] == merged[-1][1]:
            # touching epoch: extend the previous epoch to this one's stop
            merged[-1][1] = row[1]
        else:
            merged.append(row.copy())
    return np.vstack(merged)


def nan_invalid_segments(rec):
    """
    Currently a specialized function for removing incorrect trials from data
    collected using baphy during behavior.

    Selects REFERENCE epochs that fall inside HIT_TRIAL or PASSIVE_EXPERIMENT
    epochs (as times, via ``get_epoch_bounds``), merges touching epochs and
    passes the result to ``rec.nan_times``.

    Side effect: ``rec['resp']`` is replaced in ``rec`` by its rasterized
    version.

    TODO: Migrate to nems_db or make a more generic version
    """
    # First, select the appropriate subset of data
    rec['resp'] = rec['resp'].rasterize()
    sig = rec['resp']

    # start/stop times of reference epochs during hit trials or passive blocks
    epoch_bounds = np.vstack((
        ep.epoch_intersection(sig.get_epoch_bounds('HIT_TRIAL'),
                              sig.get_epoch_bounds('REFERENCE')),
        ep.epoch_intersection(sig.get_epoch_bounds('REFERENCE'),
                              sig.get_epoch_bounds('PASSIVE_EXPERIMENT'))))

    # Only takes the first of any conflicts
    epoch_bounds = ep.remove_overlap(epoch_bounds)

    # Merge directly adjacent epochs. The previous in-line loop set the merged
    # stop to the *start* of the next epoch, so touching epochs were dropped
    # instead of merged.
    epoch_bounds = _merge_adjacent_epochs(epoch_bounds)

    # add adjusted signals to the recording
    return rec.nan_times(epoch_bounds)
def select_times(recording, subset, random_only=True, dual_only=True):
    '''
    Restrict a recording to a subset of times, optionally dropping the
    repeating segments and/or keeping only the dual-stream segments.

    Parameters
    ----------
    recording : nems.recording.Recording
        The recording object. Epochs named 'repeating' and 'dual' are looked
        up in ``recording['stim'].epochs``.
    subset : Nx2 array
        Epochs representing the selected subset (e.g., from an est/val split).
    random_only : bool
        If True, remove the 'repeating' epochs from the subset, i.e. return
        only the random (non-repeating) portion of the subset.
    dual_only : bool
        If True, return only the dual stream portion of the subset
        (its intersection with the 'dual' epochs).

    Returns
    -------
    The result of ``recording.select_times`` on the final subset.
    '''
    epochs = recording['stim'].epochs
    names = epochs['name']

    if random_only:
        repeating_epochs = epochs.loc[names == 'repeating', ['start', 'end']].values
        subset = epoch_difference(subset, repeating_epochs)
    if dual_only:
        dual_epochs = epochs.loc[names == 'dual', ['start', 'end']].values
        subset = epoch_intersection(subset, dual_epochs)

    return recording.select_times(subset)
def test_intersection_bug():
    """
    Regression test with bounds copied from real data: ``b`` differs from
    ``a`` only by ~1e-13 floating-point noise, and the intersection is
    expected to come back exactly equal to the matching rows of ``a``.
    """
    # Copied from real data
    a = np.array([
        [57.7, 61.4],
        [61.4, 66.6],
        [66.6, 70.9],
        [70.9, 75.80000000000001],
        [75.8, 81.0],
    ])

    b = np.array([
        [57.7000000000001, 61.4000000000001],
        [61.4000000000001, 66.6000000000001],
        [66.6000000000001, 70.9],
        [75.8000000000001, 81.0000000000001],
        [81.0000000000001, 85.6000000000001],
    ])

    expected = np.array([
        [57.7, 61.4],
        [61.4, 66.6],
        [66.6, 70.9],
        [75.8, 81.0],
    ])

    result = epoch_intersection(a, b)
    # assert_array_equal also checks the shape; np.all(result == expected)
    # would broadcast (or raise) rather than fail cleanly on a shape mismatch.
    np.testing.assert_array_equal(result, expected)
def test_intersection(epoch_a, epoch_b):
    """
    epoch_intersection of ``epoch_a`` and ``epoch_b`` yields the listed
    overlaps (presumably both arguments are pytest fixtures — confirm).
    """
    expected = np.array([
        [60, 70],
        [75, 76],
        [90, 95],
    ])
    result = epoch_intersection(epoch_a, epoch_b)
    # checks shape as well as values, unlike np.all(result == expected)
    np.testing.assert_array_equal(result, expected)
def test_empty_intersection():
    """Intersecting non-overlapping epoch sets should emit a RuntimeWarning."""
    # NOTE(review): another test_empty_intersection is defined right after
    # this one in the file and shadows it, so pytest never collects this one.
    first = np.array([[0, 20], [50, 70]])
    second = np.array([[25, 30], [30, 45]])
    with pytest.warns(RuntimeWarning):
        overlap = epoch_intersection(first, second)
    print(overlap)
def test_empty_intersection():
    """
    Intersecting two epoch sets with no overlap should emit a RuntimeWarning
    (the intent was "Expected RuntimeWarning for size 0").
    """
    # NOTE(review): this redefines an earlier test_empty_intersection in this
    # file, shadowing it; consider renaming one of them.
    a = np.array([[0, 20], [50, 70]])
    b = np.array([[25, 30], [30, 45]])
    # A RuntimeWarning is a warning, not an exception, so pytest.raises only
    # sees it if warnings are escalated to errors; in addition, pytest.raises'
    # ``message=`` keyword was removed in pytest 5.0. pytest.warns is the
    # correct check.
    with pytest.warns(RuntimeWarning):
        epoch_intersection(a, b)
def remove_invalid_segments(rec):
    """
    Keep only the reference segments of hit trials and passive blocks.

    Specialized for data collected using baphy during behavior.
    TODO: Migrate to nems_lbhb or make a more generic version

    Parameters
    ----------
    rec : recording
        Must contain a 'resp' signal with REFERENCE, HIT_TRIAL and
        PASSIVE_EXPERIMENT epochs. Side effect: rec['resp'] (and rec['stim'],
        when present) are replaced in ``rec`` by their rasterized versions.

    Returns
    -------
    The result of ``rec.select_times`` on the merged segments, in seconds.
    """
    rec['resp'] = rec['resp'].rasterize()
    if 'stim' in rec.signals.keys():
        rec['stim'] = rec['stim'].rasterize()
    sig = rec['resp']

    # reference bins inside hit trials, then reference bins inside passive
    hit_refs = ep.epoch_intersection(sig.get_epoch_indices('REFERENCE'),
                                     sig.get_epoch_indices('HIT_TRIAL'))
    passive_refs = ep.epoch_intersection(
        sig.get_epoch_indices('REFERENCE'),
        sig.get_epoch_indices('PASSIVE_EXPERIMENT'))

    # keep only the first of any overlapping segments
    segments = ep.remove_overlap(np.vstack((hit_refs, passive_refs)))

    # fuse runs of back-to-back segments (stop of one == start of next)
    merged = segments[0:1]
    for row_idx in range(1, segments.shape[0]):
        nxt = segments[row_idx:row_idx + 1]
        if nxt[0, 0] == merged[-1, 1]:
            merged[-1, 1] = nxt[0, 1]
        else:
            merged = np.concatenate((merged, nxt), axis=0)

    # bin indices -> seconds
    return rec.select_times(merged / sig.fs)
def make_state_signal(rec, state_signals=['pupil'], permute_signals=[],
                      new_signalname='state'):
    """
    Generate a state signal for stategain.S/sdexp.S models.

    Valid state signals include (incomplete list):
        pupil, pupil_ev, pupil_bs, pupil_psd
        active, each_file, each_passive, each_half
        far, hit, lick, p_x_a

    Parameters
    ----------
    rec : recording
        Must contain 'resp', plus whatever signals the requested channels
        read (e.g. 'pupil'; 'mask' when any channel is permuted).
    state_signals : list of str
        Channels to stack into the state signal. The meta-names
        'each_passive', 'each_file' and 'each_half' are expanded into one
        channel per file (or file half). The caller's list is not modified.
    permute_signals : list of str
        Channels (from state_signals) that are time-shuffled within 'mask'
        before stacking. The caller's list is not modified.
    new_signalname : str
        Name of the stacked state signal. An unshuffled copy of every channel
        is also stacked into ``new_signalname + "_raw"``.

    Returns
    -------
    newrec : recording
        A copy of ``rec`` with the state signals added.

    TODO: Migrate to nems_lbhb or make a more generic version
    """
    # Work on private copies: the each_* expansions below remove/extend these
    # lists, which used to edit the caller's lists (and the shared mutable
    # defaults), so a second call with the same list behaved differently.
    state_signals = list(state_signals)
    permute_signals = list(permute_signals)

    newrec = rec.copy()
    resp = newrec['resp'].rasterize()

    # z-score the pupil trace if it (or a derivative of it) is being used
    if ('pupil' in state_signals) or ('pupil_ev' in state_signals) or \
            ('pupil_bs' in state_signals):
        p = newrec["pupil"].as_continuous().copy()
        p -= np.nanmean(p)
        p /= np.nanstd(p)
        newrec["pupil"] = newrec["pupil"]._modified_copy(p)

    if 'pupil_psd' in state_signals:
        pup = newrec['pupil'].as_continuous().copy()
        fs = newrec['pupil'].fs
        # sliding 60 s spectrogram of pupil, advancing one bin per column
        nperseg = int(60 * fs)
        noverlap = nperseg - 1
        _, _, Sxx = ss.spectrogram(pup.squeeze(), fs=fs, nperseg=nperseg,
                                   noverlap=noverlap)
        max_chan = 4  # keep only the lowest-frequency rows
        # pad both ends with the edge columns, intended to restore the
        # length of the pupil trace
        pad1 = np.ones((max_chan, int(nperseg / 2))) * Sxx[:max_chan, [0]]
        pad2 = np.ones((max_chan, int(nperseg / 2 - 1))) * Sxx[:max_chan, [-1]]
        newspec = np.concatenate((pad1, Sxx[:max_chan, :], pad2), axis=1)

        newspec -= np.nanmean(newspec, axis=1, keepdims=True)
        newspec /= np.nanstd(newspec, axis=1, keepdims=True)

        spec_signal = newrec['pupil']._modified_copy(newspec)
        spec_signal.name = 'pupil_psd'
        spec_signal.chans = ['puppsd{0}'.format(chan)
                             for chan in range(newspec.shape[0])]
        newrec.add_signal(spec_signal)

    if ('pupil_ev' in state_signals) or ('pupil_bs' in state_signals):
        # split pupil into a per-trial baseline (mean over the first
        # PreStimSilence-length bins of each trial) and the evoked residual
        prestimsilence = newrec["pupil"].extract_epoch('PreStimSilence')
        spont_bins = prestimsilence.shape[2]
        pupil_trial = newrec["pupil"].extract_epoch('TRIAL')

        pupil_bs = np.zeros(pupil_trial.shape)
        for ii in range(pupil_trial.shape[0]):
            pupil_bs[ii, :, :] = np.mean(pupil_trial[ii, :, :spont_bins])
        pupil_ev = pupil_trial - pupil_bs

        newrec['pupil_ev'] = newrec["pupil"].replace_epoch('TRIAL', pupil_ev)
        newrec['pupil_ev'].chans = ['pupil_ev']
        newrec['pupil_bs'] = newrec["pupil"].replace_epoch('TRIAL', pupil_bs)
        newrec['pupil_bs'].chans = ['pupil_bs']

    if 'each_passive' in state_signals:
        # one channel per passive FILE_ epoch, except the first passive file
        file_epochs = ep.epoch_names_matching(resp.epochs, "^FILE_")
        pset = []
        found_passive1 = False
        for f in file_epochs:
            # test if passive expt
            epoch_indices = ep.epoch_intersection(
                resp.get_epoch_indices(f),
                resp.get_epoch_indices('PASSIVE_EXPERIMENT'))
            if epoch_indices.size:
                if not found_passive1:
                    # skip first passive
                    found_passive1 = True
                else:
                    pset.append(f)
                    newrec[f] = resp.epoch_to_signal(f)
        state_signals.remove('each_passive')
        state_signals.extend(pset)
        if 'each_passive' in permute_signals:
            permute_signals.remove('each_passive')
            permute_signals.extend(pset)

    if 'each_file' in state_signals:
        # one channel per FILE_ epoch, named PASSIVE_n / ACTIVE_n; the first
        # passive file is the baseline and gets no channel
        file_epochs = ep.epoch_names_matching(resp.epochs, "^FILE_")
        passive_indices = resp.get_epoch_indices('PASSIVE_EXPERIMENT')
        pset = []
        pcount = 0
        acount = 0
        for f in file_epochs:
            # test if passive expt
            f_indices = resp.get_epoch_indices(f)
            epoch_indices = ep.epoch_intersection(f_indices, passive_indices)

            if epoch_indices.size:
                # this is a passive file
                name1 = "PASSIVE_{}".format(pcount)
                pcount += 1
                if pcount == 1:
                    acount = 1  # reset acount for actives after first passive
                else:
                    # use first passive part A as baseline - don't model
                    pset.append(name1)
                    newrec[name1] = resp.epoch_to_signal(name1,
                                                         indices=f_indices)
            else:
                name1 = "ACTIVE_{}".format(acount)
                pset.append(name1)
                newrec[name1] = resp.epoch_to_signal(name1, indices=f_indices)
                # actives before the first passive count down from 0
                if pcount == 0:
                    acount -= 1
                else:
                    acount += 1

        state_signals.remove('each_file')
        state_signals.extend(pset)
        if 'each_file' in permute_signals:
            permute_signals.remove('each_file')
            permute_signals.extend(pset)

    if 'each_half' in state_signals:
        # like each_file, but every file is split into halves A and B at the
        # midpoint between its first trial start and last trial stop
        file_epochs = ep.epoch_names_matching(resp.epochs, "^FILE_")
        trial_indices = resp.get_epoch_indices('TRIAL')
        passive_indices = resp.get_epoch_indices('PASSIVE_EXPERIMENT')
        pset = []
        pcount = 0
        acount = 0
        for f in file_epochs:
            # test if passive expt
            f_indices = resp.get_epoch_indices(f)
            epoch_indices = ep.epoch_intersection(f_indices, passive_indices)
            trial_intersect = ep.epoch_intersection(f_indices, trial_indices)
            _t1 = trial_intersect[0, 0]
            _t2 = trial_intersect[-1, 1]
            _split = int((_t1 + _t2) / 2)
            epoch1 = np.array([[_t1, _split]])
            epoch2 = np.array([[_split, _t2]])

            if epoch_indices.size:
                # this is a passive file
                name1 = "PASSIVE_{}_{}".format(pcount, 'A')
                name2 = "PASSIVE_{}_{}".format(pcount, 'B')
                pcount += 1
                if pcount == 1:
                    acount = 1  # reset acount for actives after first passive
                else:
                    # don't model PASSIVE_0 A -- baseline
                    pset.append(name1)
                    newrec[name1] = resp.epoch_to_signal(name1, indices=epoch1)
                # do include part B (also for the first passive file)
                pset.append(name2)
                newrec[name2] = resp.epoch_to_signal(name2, indices=epoch2)
            else:
                name1 = "ACTIVE_{}_{}".format(acount, 'A')
                name2 = "ACTIVE_{}_{}".format(acount, 'B')
                pset.append(name1)
                newrec[name1] = resp.epoch_to_signal(name1, indices=epoch1)
                pset.append(name2)
                newrec[name2] = resp.epoch_to_signal(name2, indices=epoch2)
                if pcount == 0:
                    acount -= 1
                else:
                    acount += 1

        state_signals.remove('each_half')
        state_signals.extend(pset)
        if 'each_half' in permute_signals:
            permute_signals.remove('each_half')
            permute_signals.extend(pset)

    # generate task state signals
    if 'pas' in state_signals:
        # NOTE(review): this builds 'pre_passive', but the stacking loop below
        # reads newrec['pas'], which this function does not create.
        fpre = (resp.epochs['name'] == "PRE_PASSIVE")
        fpost = (resp.epochs['name'] == "POST_PASSIVE")
        INCLUDE_PRE_POST = (np.sum(fpre) > 0) & (np.sum(fpost) > 0)
        if INCLUDE_PRE_POST:
            # only include pre-passive if post-passive also exists
            # otherwise the regression gets screwed up
            newrec['pre_passive'] = resp.epoch_to_signal('PRE_PASSIVE')
        else:
            # place-holder, all zeros
            newrec['pre_passive'] = resp.epoch_to_signal('XXX')
        newrec['pre_passive'].chans = ['PRE_PASSIVE']

    if 'puretone_trials' in state_signals:
        newrec['puretone_trials'] = resp.epoch_to_signal('PURETONE_BEHAVIOR')
        newrec['puretone_trials'].chans = ['puretone_trials']
    if 'easy_trials' in state_signals:
        newrec['easy_trials'] = resp.epoch_to_signal('EASY_BEHAVIOR')
        newrec['easy_trials'].chans = ['easy_trials']
    if 'hard_trials' in state_signals:
        newrec['hard_trials'] = resp.epoch_to_signal('HARD_BEHAVIOR')
        newrec['hard_trials'].chans = ['hard_trials']
    if ('active' in state_signals) or ('far' in state_signals):
        newrec['active'] = resp.epoch_to_signal('ACTIVE_EXPERIMENT')
        newrec['active'].chans = ['active']
    if (('hit_trials' in state_signals) or ('miss_trials' in state_signals) or
            ('far' in state_signals) or ('hit' in state_signals)):
        newrec['hit_trials'] = resp.epoch_to_signal('HIT_TRIAL')
        newrec['miss_trials'] = resp.epoch_to_signal('MISS_TRIAL')
        newrec['fa_trials'] = resp.epoch_to_signal('FA_TRIAL')

    # 180 s boxcar smoothing window, in bins; int() because it is used as an
    # array shape and fs is not guaranteed to be an integer
    sm_len = int(180 * newrec['resp'].fs)
    if 'far' in state_signals:
        # smoothed false-alarm rate during active blocks, zero elsewhere
        a = newrec['active'].as_continuous()
        fa = newrec['fa_trials'].as_continuous().astype(float)
        c = np.ones((1, sm_len)) / sm_len
        fa = convolve2d(fa, c, mode='same')
        fa[a] -= 0.25  # np.nanmean(fa[a])
        fa[np.logical_not(a)] = 0

        s = newrec['fa_trials']._modified_copy(fa)
        s.chans = ['far']
        s.name = 'far'
        newrec.add_signal(s)

    if 'hit' in state_signals:
        # smoothed (hit - miss) rate during active blocks, zero elsewhere
        a = newrec['active'].as_continuous()
        hr = newrec['hit_trials'].as_continuous().astype(float)
        ms = newrec['miss_trials'].as_continuous().astype(float)
        ht = hr - ms

        c = np.ones((1, sm_len)) / sm_len
        ht = convolve2d(ht, c, mode='same')
        ht[a] -= 0.1  # np.nanmean(ht[a])
        ht[np.logical_not(a)] = 0

        s = newrec['hit_trials']._modified_copy(ht)
        s.chans = ['hit']
        s.name = 'hit'
        newrec.add_signal(s)

    if 'lick' in state_signals:
        newrec['lick'] = resp.epoch_to_signal('LICK')

    # pupil interactions
    if 'p_x_a' in state_signals:
        # pupil x active interaction term
        p = newrec["pupil"].as_continuous()
        a = newrec["active"].as_continuous()
        newrec["p_x_a"] = newrec["pupil"]._modified_copy(p * a)
        newrec["p_x_a"].chans = ["p_x_a"]

    if 'prw' in state_signals:
        # add channel two of the resp to state and delete it from resp
        if len(rec['resp'].chans) != 2:
            raise ValueError("this is for pairwise fitting")
        else:
            ch2 = rec['resp'].chans[1]
            ch1 = rec['resp'].chans[0]
            newrec['prw'] = newrec['resp'].extract_channels([ch2]).rasterize()
            newrec['resp'] = newrec['resp'].extract_channels([ch1]).rasterize()

    if 'pup_x_prw' in state_signals:
        # interaction term between pupil and the other cell
        if 'prw' not in newrec.signals.keys():
            raise ValueError("Must include prw alone before using interaction")
        else:
            pup = newrec['pupil']._data
            prw = newrec['prw']._data
            sig = newrec['pupil']._modified_copy(pup * prw)
            sig.name = 'pup_x_prw'
            sig.chans = ['pup_x_prw']
            newrec.add_signal(sig)

    for i, x in enumerate(state_signals):
        if x in permute_signals:
            # kludge: fix random seed to index of state signal in list
            # this avoids using the same seed for each shuffled signal
            # but also makes shuffling reproducible
            newrec = concatenate_state_channel(
                newrec,
                newrec[x].shuffle_time(rand_seed=i, mask=newrec['mask']),
                state_signal_name=new_signalname)
        else:
            newrec = concatenate_state_channel(
                newrec, newrec[x], state_signal_name=new_signalname)
        # the *_raw state always holds the unshuffled channel
        newrec = concatenate_state_channel(
            newrec, newrec[x], state_signal_name=new_signalname + "_raw")

    return newrec
def mask_all_but_correct_references(rec, balance_rep_count=False,
                                    include_incorrect=False):
    """
    Specialized function for removing incorrect trials from data collected
    using baphy during behavior: builds a mask keeping REFERENCE epochs of
    hit trials and passive blocks.

    Parameters
    ----------
    rec : recording
        Must contain 'resp'; 'stim' is rasterized too when present.
    balance_rep_count : bool
        If True, build the mask from STIM_ epochs, subsampling the passive or
        the hit-trial occurrences of each stimulus so that both conditions
        contribute the same number of repetitions.
    include_incorrect : bool
        If True (and not balancing), keep every REFERENCE epoch, correct or
        not.

    Returns
    -------
    newrec : recording
        A masked copy of ``rec``. If it has a 'state' signal, the behavioral
        channels listed in ``b_states`` are z-scored using the mean/std of the
        unmasked active-experiment bins (applied to 'state_raw' too, when
        present).

    TODO: Migrate to nems_lbhb and/or make a more generic version
    """
    newrec = rec.copy()
    newrec['resp'] = newrec['resp'].rasterize()
    if 'stim' in newrec.signals.keys():
        newrec['stim'] = newrec['stim'].rasterize()
    resp = newrec['resp']

    if balance_rep_count:
        epoch_regex = "^STIM_"
        epochs_to_extract = ep.epoch_names_matching(resp.epochs, epoch_regex)
        p = resp.get_epoch_indices("PASSIVE_EXPERIMENT")
        a = resp.get_epoch_indices("HIT_TRIAL")

        epoch_list = []
        for s in epochs_to_extract:
            e = resp.get_epoch_indices(s)
            pe = ep.epoch_intersection(e, p)
            ae = ep.epoch_intersection(e, a)
            # keep all occurrences of the rarer condition and an evenly
            # spaced subset of the more common one
            if len(pe) > len(ae):
                epoch_list.extend(ae)
                subset = np.round(np.linspace(0, len(pe), len(ae) + 1)).astype(int)
                for i in subset[:-1]:
                    epoch_list.append(pe[i])
            else:
                subset = np.round(np.linspace(0, len(ae), len(pe) + 1)).astype(int)
                for i in subset[:-1]:
                    epoch_list.append(ae[i])
                epoch_list.extend(pe)

        newrec = newrec.create_mask(epoch_list)

    elif include_incorrect:
        log.info('INCLUDING ALL TRIALS (CORRECT AND INCORRECT)')
        newrec = newrec.and_mask(['REFERENCE'])

    else:
        newrec = newrec.and_mask(['PASSIVE_EXPERIMENT', 'HIT_TRIAL'])
        newrec = newrec.and_mask(['REFERENCE'])

    if 'state' in newrec.signals:
        b_states = ['far', 'hit', 'lick',
                    'puretone_trials', 'easy_trials', 'hard_trials']
        has_raw = 'state_raw' in newrec.signals

        # statistics come only from unmasked bins inside active blocks
        trec = newrec.copy()
        trec = trec.and_mask(['ACTIVE_EXPERIMENT'])
        st = trec['state'].as_continuous().copy()
        # named st_raw to avoid shadowing the builtin ``str``
        st_raw = trec['state_raw'].as_continuous().copy() if has_raw else None
        mask = trec['mask'].as_continuous()[0, :]

        for i, s in enumerate(trec['state'].chans):
            if s not in b_states:
                continue
            m = np.nanmean(st[i, mask])
            sd = np.nanstd(st[i, mask])
            st[i, mask] -= m
            if has_raw:
                st_raw[i, mask] -= m
            # a channel that is constant over the active bins (sd == 0) would
            # otherwise become 0/0 = NaN; leave it mean-centred instead
            if sd > 0:
                st[i, mask] /= sd
                if has_raw:
                    st_raw[i, mask] /= sd

        newrec['state'] = newrec['state']._modified_copy(st)
        if has_raw:
            newrec['state_raw'] = newrec['state_raw']._modified_copy(st_raw)

    return newrec
def ev_pupil(cellid, batch, presilence=0.35):
    """
    Mean pupil traces per behavior block type and per trial outcome.

    Loads the recording for ``cellid``/``batch``. The pupil trace is
    normalized by its maximum and rolled 0.75 s later in time (np.roll wraps
    the end around to the start). Each trial is cut to 6 s (clipped at the
    end of the recording).

    Parameters
    ----------
    cellid, batch :
        Passed to ``nw.generate_recording_uri``.
    presilence : float
        Seconds at the start of each trial used as its baseline.

    Returns
    -------
    beh_lickrate : list of (label, mean trace)
        One entry per passive/puretone/easy/hard block, in start order.
        Despite the name, these are pupil traces.
    beh_lickrate_norm : list of (label, mean trace)
        Same, after subtracting each trial's baseline mean.
    beh_all : dict label -> 1-D array
        All raw trial traces of each block type, concatenated.
    perf_lickrate, perf_lickrate_norm, perf_all :
        Same three outputs for 'hits', 'misses' and 'fas' trials (a NaN trace
        when a category has no trials).
    """
    modelname = "psth.fs20.pup-st.pup.beh_stategain.3_init.st-basic"

    print('Finding recording for cell/batch {0}/{1}...'.format(cellid, batch))

    # the loader keyword is the first "_"-separated part of the modelname
    loader = modelname.split("_")[0]
    recording_uri = nw.generate_recording_uri(cellid, batch, loader)
    print(recording_uri)
    rec = recording.load_recording(recording_uri)

    pupil = rec['pupil']
    # builtin int: np.int was removed in NumPy 1.24
    pshift = int(pupil.fs * 0.75)
    d = pupil._data
    # roll along time explicitly; np.roll(d, (0, pshift)) without ``axis``
    # flattened the array first (same result only for a single channel)
    pupil = pupil._modified_copy(np.roll(d / np.nanmax(d), pshift, axis=1))

    trials = pupil.get_epoch_indices('TRIAL')
    trialbins = int(pupil.fs * 6)
    prebins = int(pupil.fs * presilence)
    maxbin = pupil.shape[1]

    # [start, stop, label] for every block, sorted by start bin
    blocks = []
    for epoch_name, label in (('PASSIVE_EXPERIMENT', 'passive'),
                              ('PURETONE_BEHAVIOR', 'puretone'),
                              ('EASY_BEHAVIOR', 'easy'),
                              ('HARD_BEHAVIOR', 'hard')):
        for b in pupil.get_epoch_indices(epoch_name).tolist():
            blocks.append(b + [label])
    blocks.sort()

    ev = []
    ev_prenorm = []
    beh_lickrate = []
    beh_lickrate_norm = []
    beh_all = {}
    for block in blocks:
        k = block[-1]
        block_trials = ep.epoch_intersection(trials, np.array([block[0:2]]))
        # fixed-length trials, clipped at the end of the recording
        block_trials[:, 1] = np.minimum(block_trials[:, 0] + trialbins, maxbin)
        tev = pupil.extract_epoch(block_trials)[:, 0, :]

        ev.append(tev)
        ev_prenorm.append(tev - np.mean(tev[:, :prebins], axis=1, keepdims=True))
        if k not in beh_all:
            beh_all[k] = np.array([])
        beh_all[k] = np.append(beh_all[k], tev.ravel())
        beh_lickrate.append((k, np.nanmean(ev[-1], axis=0)))
        beh_lickrate_norm.append((k, np.nanmean(ev_prenorm[-1], axis=0)))

    perf_blocks = {
        'hits': pupil.get_epoch_indices('HIT_TRIAL'),
        'misses': pupil.get_epoch_indices('MISS_TRIAL'),
        'fas': pupil.get_epoch_indices('FA_TRIAL')
    }

    ev = []
    ev_prenorm = []
    perf_lickrate = []
    perf_lickrate_norm = []
    perf_all = {}
    for k, block in perf_blocks.items():
        block_trials = ep.epoch_intersection(trials, block)
        block_trials[:, 1] = np.minimum(block_trials[:, 0] + trialbins, maxbin)
        t = pupil.extract_epoch(block_trials, allow_empty=True)
        if t.size:
            tev = t[:, 0, :]
        else:
            # no trials of this outcome: a single all-NaN trace
            tev = np.ones((1, trialbins)) * np.nan
        perf_all[k] = t.ravel()

        ev.append(tev)
        ev_prenorm.append(tev - np.mean(tev[:, :prebins], axis=1, keepdims=True))
        perf_lickrate.append((k, np.nanmean(ev[-1], axis=0)))
        perf_lickrate_norm.append((k, np.nanmean(ev_prenorm[-1], axis=0)))

    return beh_lickrate, beh_lickrate_norm, beh_all, \
        perf_lickrate, perf_lickrate_norm, perf_all
def make_state_signal(rec, state_signals=['pupil'], permute_signals=[],
                      new_signalname='state'):
    """
    Generate a state signal for stategainX models.

    Parameters
    ----------
    rec : recording
        Must contain 'resp' (and 'pupil' when pupil-derived channels are
        requested).
    state_signals : list of str
        Channels to stack into the state signal. 'each_passive' is expanded
        into one channel per passive FILE_ epoch. The caller's list is not
        modified.
    permute_signals : list of str
        Channels (from state_signals) that are time-shuffled before stacking.
        The caller's list is not modified.
    new_signalname : str
        Name of the stacked state signal.

    Returns
    -------
    newrec : recording
        A copy of ``rec`` with the state signal added.

    TODO: Migrate to nems_lbhb or make a more generic version
    """
    # Work on private copies: the 'each_passive' expansion edits these lists,
    # which used to modify the caller's lists (and the mutable defaults).
    state_signals = list(state_signals)
    permute_signals = list(permute_signals)

    newrec = rec.copy()
    resp = newrec['resp'].rasterize()

    # z-score the pupil trace if it (or a derivative of it) is being used
    if ('pupil' in state_signals) or ('pupil_ev' in state_signals) or \
            ('pupil_bs' in state_signals):
        p = newrec["pupil"].as_continuous().copy()
        p -= np.nanmean(p)
        p /= np.nanstd(p)
        newrec["pupil"] = newrec["pupil"]._modified_copy(p)

    if 'pupil_psd' in state_signals:
        pup = newrec['pupil'].as_continuous().copy()
        fs = newrec['pupil'].fs
        # sliding 60 s spectrogram of pupil, advancing one bin per column
        nperseg = int(60 * fs)
        noverlap = nperseg - 1
        _, _, Sxx = ss.spectrogram(pup.squeeze(), fs=fs, nperseg=nperseg,
                                   noverlap=noverlap)
        max_chan = 4  # keep only the lowest-frequency rows
        # pad both ends with the edge columns, intended to restore the
        # length of the pupil trace
        pad1 = np.ones((max_chan, int(nperseg / 2))) * Sxx[:max_chan, [0]]
        pad2 = np.ones((max_chan, int(nperseg / 2 - 1))) * Sxx[:max_chan, [-1]]
        newspec = np.concatenate((pad1, Sxx[:max_chan, :], pad2), axis=1)

        newspec -= np.nanmean(newspec, axis=1, keepdims=True)
        newspec /= np.nanstd(newspec, axis=1, keepdims=True)

        spec_signal = newrec['pupil']._modified_copy(newspec)
        spec_signal.name = 'pupil_psd'
        spec_signal.chans = ['puppsd{0}'.format(chan)
                             for chan in range(newspec.shape[0])]
        newrec.add_signal(spec_signal)

    if ('pupil_ev' in state_signals) or ('pupil_bs' in state_signals):
        # split pupil into a per-trial baseline (mean over the first
        # PreStimSilence-length bins of each trial) and the evoked residual
        prestimsilence = newrec["pupil"].extract_epoch('PreStimSilence')
        spont_bins = prestimsilence.shape[2]
        pupil_trial = newrec["pupil"].extract_epoch('TRIAL')

        pupil_bs = np.zeros(pupil_trial.shape)
        for ii in range(pupil_trial.shape[0]):
            pupil_bs[ii, :, :] = np.mean(pupil_trial[ii, :, :spont_bins])
        pupil_ev = pupil_trial - pupil_bs

        newrec['pupil_ev'] = newrec["pupil"].replace_epoch('TRIAL', pupil_ev)
        newrec['pupil_ev'].chans = ['pupil_ev']
        newrec['pupil_bs'] = newrec["pupil"].replace_epoch('TRIAL', pupil_bs)
        newrec['pupil_bs'].chans = ['pupil_bs']

    if 'each_passive' in state_signals:
        # one channel per FILE_ epoch that overlaps a passive experiment
        file_epochs = ep.epoch_names_matching(resp.epochs, "^FILE_")
        pset = []
        for f in file_epochs:
            epoch_indices = ep.epoch_intersection(
                resp.get_epoch_indices(f),
                resp.get_epoch_indices('PASSIVE_EXPERIMENT'))
            if epoch_indices.size:
                pset.append(f)
                newrec[f] = resp.epoch_to_signal(f)
        state_signals.remove('each_passive')
        state_signals.extend(pset)
        if 'each_passive' in permute_signals:
            permute_signals.remove('each_passive')
            permute_signals.extend(pset)

    # generate task state signals
    fpre = (resp.epochs['name'] == "PRE_PASSIVE")
    fpost = (resp.epochs['name'] == "POST_PASSIVE")
    INCLUDE_PRE_POST = (np.sum(fpre) > 0) & (np.sum(fpost) > 0)
    if INCLUDE_PRE_POST:
        # only include pre-passive if post-passive also exists
        # otherwise the regression gets screwed up
        newrec['pre_passive'] = resp.epoch_to_signal('PRE_PASSIVE')
    else:
        # place-holder, all zeros
        newrec['pre_passive'] = resp.epoch_to_signal('XXX')
    newrec['pre_passive'].chans = ['PRE_PASSIVE']
    newrec['hit_trials'] = resp.epoch_to_signal('HIT_TRIAL')
    newrec['miss_trials'] = resp.epoch_to_signal('MISS_TRIAL')
    newrec['fa_trials'] = resp.epoch_to_signal('FA_TRIAL')
    newrec['puretone_trials'] = resp.epoch_to_signal('PURETONE_BEHAVIOR')
    newrec['puretone_trials'].chans = ['puretone_trials']
    newrec['easy_trials'] = resp.epoch_to_signal('EASY_BEHAVIOR')
    newrec['easy_trials'].chans = ['easy_trials']
    newrec['hard_trials'] = resp.epoch_to_signal('HARD_BEHAVIOR')
    newrec['hard_trials'].chans = ['hard_trials']
    newrec['active'] = resp.epoch_to_signal('ACTIVE_EXPERIMENT')
    newrec['active'].chans = ['active']

    if 'lick' in state_signals:
        newrec['lick'] = resp.epoch_to_signal('LICK')

    # pupil interactions
    if 'pupil' in state_signals:
        # pupil x active interaction term
        p = newrec["pupil"].as_continuous()
        a = newrec["active"].as_continuous()
        newrec["p_x_a"] = newrec["pupil"]._modified_copy(p * a)
        newrec["p_x_a"].chans = ["p_x_a"]

    for i, x in enumerate(state_signals):
        if x in permute_signals:
            # kludge: fix random seed to index of state signal in list
            # this avoids using the same seed for each shuffled signal
            # but also makes shuffling reproducible
            newrec = concatenate_state_channel(
                newrec, newrec[x].shuffle_time(rand_seed=i),
                state_signal_name=new_signalname)
        else:
            newrec = concatenate_state_channel(
                newrec, newrec[x], state_signal_name=new_signalname)

    return newrec
def _model_step_plot(cellid, batch, modelnames, factors, state_colors=None):
    """
    Compare four nested state models for one cell and plot the results.

    Parameters
    ----------
    cellid, batch :
        Passed to ``xhelp.load_model_xform``.
    modelnames : sequence of 4 str
        Unpacked as (p0b0, p0b, pb0, pb). From the usage below, the names
        presumably denote models with the first/second state factor
        shuffled (0) or intact — confirm.
    factors : list of 3 str
        State factor names; factors[0] labels the top panel. If it contains
        'each_passive', that entry is replaced *in the caller's list* by every
        state channel whose name starts with 'FILE_'.
    state_colors : N x 2 list
       color spec for high/low lines in each of the N states

    Returns
    -------
    fh : matplotlib Figure
    stats : dict
        r_test/se_test/r_floor and the phi 'g'/'d' parameters of the four
        models, the modulation indices (transposed), and reference/target
        response summaries.
    """
    global line_colors
    global fill_colors

    modelname_p0b0, modelname_p0b, modelname_pb0, modelname_pb = \
        modelnames
    factor0, factor1, factor2 = factors

    # load each model and evaluate its xforms up to (not including) the last
    # two steps
    xf_p0b0, ctx_p0b0 = xhelp.load_model_xform(cellid, batch, modelname_p0b0,
                                               eval_model=False)
    # ctx_p0b0, l = xforms.evaluate(xf_p0b0, ctx_p0b0, stop=-2)
    ctx_p0b0, l = xforms.evaluate(xf_p0b0, ctx_p0b0, start=0, stop=-2)

    xf_p0b, ctx_p0b = xhelp.load_model_xform(cellid, batch, modelname_p0b,
                                             eval_model=False)
    ctx_p0b, l = xforms.evaluate(xf_p0b, ctx_p0b, start=0, stop=-2)

    xf_pb0, ctx_pb0 = xhelp.load_model_xform(cellid, batch, modelname_pb0,
                                             eval_model=False)
    #ctx_pb0['rec'] = ctx_p0b0['rec'].copy()
    ctx_pb0, l = xforms.evaluate(xf_pb0, ctx_pb0, start=0, stop=-2)

    xf_pb, ctx_pb = xhelp.load_model_xform(cellid, batch, modelname_pb,
                                           eval_model=False)
    #ctx_pb['rec'] = ctx_p0b0['rec'].copy()
    ctx_pb, l = xforms.evaluate(xf_pb, ctx_pb, start=0, stop=-2)

    # organize predictions by different models: the full (pb) model's val
    # recording gets the two partial models' predictions added alongside it
    val = ctx_pb['val'][0].copy()

    # val['pred_p0b0'] = ctx_p0b0['val'][0]['pred'].copy()
    val['pred_p0b'] = ctx_p0b['val'][0]['pred'].copy()
    val['pred_pb0'] = ctx_pb0['val'][0]['pred'].copy()

    state_var_list = val['state'].chans

    pred_mod = np.zeros([len(state_var_list), 2])
    pred_mod_full = np.zeros([len(state_var_list), 2])
    resp_mod_full = np.zeros([len(state_var_list), 1])

    # channels with zero std are constant; their modulation indices stay 0
    state_std = np.nanstd(val['state'].as_continuous(), axis=1, keepdims=True)
    for i, var in enumerate(state_var_list):
        if state_std[i]:
            # actual response modulation index for each state var
            resp_mod_full[i] = state_mod_index(val, epoch='REFERENCE',
                                               psth_name='resp', state_chan=var)

            mod2_p0b = state_mod_index(val, epoch='REFERENCE',
                                       psth_name='pred_p0b', state_chan=var)
            mod2_pb0 = state_mod_index(val, epoch='REFERENCE',
                                       psth_name='pred_pb0', state_chan=var)
            mod2_pb = state_mod_index(val, epoch='REFERENCE',
                                      psth_name='pred', state_chan=var)

            # unique modulation: full model minus each partial model
            pred_mod[i] = np.array([mod2_pb - mod2_p0b, mod2_pb - mod2_pb0])
            pred_mod_full[i] = np.array([mod2_pb0, mod2_p0b])

    # divide by each channel's std; constant channels are divided by 1
    pred_mod_norm = pred_mod / (state_std + (state_std == 0).astype(float))
    pred_mod_full_norm = pred_mod_full / (state_std +
                                          (state_std == 0).astype(float))

    # NOTE(review): this mutates the caller's ``factors`` list.
    if 'each_passive' in factors:
        psth_names_ctl = ["pred_p0b"]
        factors.remove('each_passive')
        for v in state_var_list:
            if v.startswith('FILE_'):
                factors.append(v)
                psth_names_ctl.append("pred_pb0")
    else:
        psth_names_ctl = ["pred_p0b", "pred_pb0"]

    col_count = len(factors) - 1
    if state_colors is None:
        state_colors = [[None, None]] * col_count

    fh = plt.figure(figsize=(8, 8))

    # row 1: state variable timeseries for the full model
    ax = plt.subplot(4, 1, 1)
    nplt.state_vars_timeseries(val, ctx_pb['modelspecs'][0],
                               state_colors=[s[1] for s in state_colors])
    ax.set_title("{}/{} - {}".format(cellid, batch, modelname_pb))
    ax.set_ylabel("{} r={:.3f}".format(
        factor0, ctx_p0b0['modelspecs'][0][0]['meta']['r_test'][0]))
    nplt.ax_remove_box(ax)

    # rows 2-3: per-factor PSTHs split by state, control vs. full model
    for i, var in enumerate(factors[1:]):
        if var.startswith('FILE_'):
            varlbl = var[5:]
        else:
            varlbl = var

        ax = plt.subplot(4, col_count, col_count + i + 1)

        nplt.state_var_psth_from_epoch(val, epoch="REFERENCE",
                                       psth_name="resp",
                                       psth_name2=psth_names_ctl[i],
                                       state_chan=var, ax=ax,
                                       colors=state_colors[i])
        if i == 0:
            ax.set_ylabel("Control model")
            ax.set_title("{} ctl r={:.3f}".format(
                varlbl.lower(),
                ctx_p0b['modelspecs'][0][0]['meta']['r_test'][0]),
                fontsize=6)
        else:
            ax.yaxis.label.set_visible(False)
            ax.set_title("{} ctl r={:.3f}".format(
                varlbl.lower(),
                ctx_pb0['modelspecs'][0][0]['meta']['r_test'][0]),
                fontsize=6)
        if ax.legend_:
            ax.legend_.remove()
        ax.xaxis.label.set_visible(False)
        nplt.ax_remove_box(ax)

        ax = plt.subplot(4, col_count, col_count * 2 + i + 1)
        nplt.state_var_psth_from_epoch(val, epoch="REFERENCE",
                                       psth_name="resp",
                                       psth_name2="pred",
                                       state_chan=var, ax=ax,
                                       colors=state_colors[i])
        if i == 0:
            ax.set_ylabel("Full Model")
        else:
            ax.yaxis.label.set_visible(False)
        if ax.legend_:
            ax.legend_.remove()

        if psth_names_ctl[i] == "pred_p0b":
            j = 0
        else:
            j = 1
        # NOTE(review): row i + 1 assumes factors[1:] line up with
        # state_var_list[1:] (i.e. state channel 0 is not a factor) — confirm
        ax.set_title("r={:.3f} rawmod={:.3f} umod={:.3f}".format(
            ctx_pb['modelspecs'][0][0]['meta']['r_test'][0],
            pred_mod_full_norm[i + 1][j], pred_mod_norm[i + 1][j]),
            fontsize=6)

        if var == 'active':
            ax.legend(('pas', 'act'))
        elif var == 'pupil':
            ax.legend(('small', 'large'))
        elif var == 'PRE_PASSIVE':
            ax.legend(('act+post', 'pre'))
        elif var.startswith('FILE_'):
            ax.legend(('this', 'others'))
        nplt.ax_remove_box(ax)

    # EXTRA PANELS
    # figure out some basic aspects of tuning/selectivity for target vs.
    # reference:
    r = ctx_pb['rec']['resp']
    e = r.epochs
    fs = r.fs

    # target responses during passive blocks only; "* fs" scales the
    # per-bin values, presumably to spikes/s
    passive_epochs = r.get_epoch_indices("PASSIVE_EXPERIMENT")
    tar_names = ep.epoch_names_matching(e, "^TAR_")
    tar_resp = {}
    for tarname in tar_names:
        t = r.get_epoch_indices(tarname)
        t = ep.epoch_intersection(t, passive_epochs)
        tar_resp[tarname] = r.extract_epoch(t) * fs

    # only plot tar responses with max SNR or probe SNR
    keys = []
    for k in list(tar_resp.keys()):
        if k.endswith('0') | k.endswith('2'):
            keys.append(k)
    keys.sort()

    # assume the reference with most reps is the one overlapping the target
    groups = ep.group_epochs_by_occurrence_counts(e, '^STIM_')
    l = np.array(list(groups.keys()))
    hi = np.max(l)
    ref_name = groups[hi][0]
    t = r.get_epoch_indices(ref_name)
    t = ep.epoch_intersection(t, passive_epochs)
    ref_resp = r.extract_epoch(t) * fs

    t = r.get_epoch_indices('REFERENCE')
    t = ep.epoch_intersection(t, passive_epochs)
    all_ref_resp = r.extract_epoch(t) * fs

    prestimsilence = r.get_epoch_indices('PreStimSilence')
    prebins = prestimsilence[0, 1] - prestimsilence[0, 0]
    poststimsilence = r.get_epoch_indices('PostStimSilence')
    postbins = poststimsilence[0, 1] - poststimsilence[0, 0]
    # NOTE(review): the evoked window is [prebins, total - prebins), i.e. it
    # assumes post-stimulus silence is as long as pre-stimulus silence;
    # postbins is computed but not used.
    durbins = ref_resp.shape[-1] - prebins

    # mean evoked response minus spontaneous (pre-stimulus) rate
    spont = np.nanmean(all_ref_resp[:, 0, :prebins])
    ref_mean = np.nanmean(ref_resp[:, 0, prebins:durbins]) - spont
    all_ref_mean = np.nanmean(all_ref_resp[:, 0, prebins:durbins]) - spont
    #print(spont)
    #print(np.nanmean(ref_resp[:,0,prebins:-postbins]))
    ax1 = plt.subplot(4, 2, 7)
    ref_psth = [
        np.nanmean(ref_resp[:, 0, :], axis=0),
        np.nanmean(all_ref_resp[:, 0, :], axis=0)
    ]
    ll = [
        "{} {:.1f}".format(ref_name, ref_mean),
        "all refs {:.1f}".format(all_ref_mean)
    ]
    nplt.timeseries_from_vectors(ref_psth, fs=fs, legend=ll, ax=ax1,
                                 time_offset=prebins / fs)
    ax2 = plt.subplot(4, 2, 8)
    ll = []
    # at least two slots so tar_mean[0]/[1] exist for stats (NaN if unused)
    tar_mean = np.zeros(np.max([2, len(keys)])) * np.nan
    tar_psth = []
    for ii, k in enumerate(keys):
        tar_psth.append(np.nanmean(tar_resp[k][:, 0, :], axis=0))
        tar_mean[ii] = np.nanmean(tar_resp[k][:, 0, prebins:durbins]) - spont
        ll.append("{} {:.1f}".format(k, tar_mean[ii]))
    nplt.timeseries_from_vectors(tar_psth, fs=fs, legend=ll, ax=ax2,
                                 time_offset=prebins / fs)
    # plt.legend(ll, fontsize=6)

    # share the y-range between the reference and target panels
    ymin = np.min([ax1.get_ylim()[0], ax2.get_ylim()[0]])
    ymax = np.max([ax1.get_ylim()[1], ax2.get_ylim()[1]])
    ax1.set_ylim([ymin, ymax])
    ax2.set_ylim([ymin, ymax])
    nplt.ax_remove_box(ax1)
    nplt.ax_remove_box(ax2)

    plt.tight_layout()

    stats = {
        'cellid': cellid,
        'batch': batch,
        'modelnames': modelnames,
        'state_vars': state_var_list,
        'factors': factors,
        'r_test': np.array([
            ctx_p0b0['modelspecs'][0][0]['meta']['r_test'][0],
            ctx_p0b['modelspecs'][0][0]['meta']['r_test'][0],
            ctx_pb0['modelspecs'][0][0]['meta']['r_test'][0],
            ctx_pb['modelspecs'][0][0]['meta']['r_test'][0]
        ]),
        'se_test': np.array([
            ctx_p0b0['modelspecs'][0][0]['meta']['se_test'][0],
            ctx_p0b['modelspecs'][0][0]['meta']['se_test'][0],
            ctx_pb0['modelspecs'][0][0]['meta']['se_test'][0],
            ctx_pb['modelspecs'][0][0]['meta']['se_test'][0]
        ]),
        'r_floor': np.array([
            ctx_p0b0['modelspecs'][0][0]['meta']['r_floor'][0],
            ctx_p0b['modelspecs'][0][0]['meta']['r_floor'][0],
            ctx_pb0['modelspecs'][0][0]['meta']['r_floor'][0],
            ctx_pb['modelspecs'][0][0]['meta']['r_floor'][0]
        ]),
        'pred_mod': pred_mod.T,
        'pred_mod_full': pred_mod_full.T,
        'pred_mod_norm': pred_mod_norm.T,
        'pred_mod_full_norm': pred_mod_full_norm.T,
        'g': np.array([
            ctx_p0b0['modelspecs'][0][0]['phi']['g'],
            ctx_p0b['modelspecs'][0][0]['phi']['g'],
            ctx_pb0['modelspecs'][0][0]['phi']['g'],
            ctx_pb['modelspecs'][0][0]['phi']['g']
        ]),
        # stored under 'b' but read from phi['d']
        'b': np.array([
            ctx_p0b0['modelspecs'][0][0]['phi']['d'],
            ctx_p0b['modelspecs'][0][0]['phi']['d'],
            ctx_pb0['modelspecs'][0][0]['phi']['d'],
            ctx_pb['modelspecs'][0][0]['phi']['d']
        ]),
        'ref_all_resp': all_ref_mean,
        'ref_common_resp': ref_mean,
        'tar_max_resp': tar_mean[0],
        'tar_probe_resp': tar_mean[1]
    }

    return fh, stats
# Load a fitted model, show it (GUI browser or quickplot), then plot spike
# rasters for a random vs. a repeating dual-stream target.
# NOTE(review): cellid, batch, modelname, browse_results and the xhelp, gui,
# epoch, np and plt names are not defined in this fragment; presumably they
# are set/imported earlier in the script — confirm.
xfspec, ctx = xhelp.load_model_xform(cellid, batch, modelname)
if browse_results:
    ex = gui.browse_xform_fit(ctx, xfspec)
else:
    ctx['modelspec'].quickplot()

r=ctx['rec']['resp']
dual = r.get_epoch_indices('dual')
# NOTE(review): rrnd9, rnd9, a and the first rep9 (immediately reassigned)
# are not used in the visible lines.
rrnd9 = r.get_epoch_indices('Stim , 10 , Reference')
rnd9 = r.get_epoch_indices('Stim , 10 , Target')
rep9 = r.get_epoch_indices('Stim , 10 , TargetRep')
rep9 = r.get_epoch_indices('target_1_repeating_dual')
a = epoch.epoch_intersection(rep9, dual)

raster_rnd = r.extract_epoch('target_0')
raster_rep = r.extract_epoch('target_0_repeating_dual')

# Raster plots: np.where on channel 0 gives (repetition, time-bin) indices of
# nonzero bins; repetitions are shifted to be 1-based for plotting.
plt.figure()
plt.subplot(2,2,1)
i, j = np.where(raster_rnd[:,0,:])
i += 1
plt.plot(j, i,' k.')
plt.title('rand')

plt.subplot(2,2,2)
i, j = np.where(raster_rep[:,0,:])
i += 1
plt.plot(j, i,' k.')
plt.title('rep')