def batch_phase_by_cond(cell_num, disp, cons=None, sfs=None, dir=-1, dp=dataPath, expName=expName):
    ''' Run phase_by_cond on every requested (contrast, sf) condition of one cell at one dispersion.

    cell_num    - 1-indexed cell number (row cell_num-1 of dataList['unitName'])
    disp        - dispersion index; must be a single value
    cons        - contrast indices to analyze; if None or empty, use all valid contrasts for this dispersion
    sfs         - sf indices to analyze; if None or empty, use all valid sfs for this dispersion
    dir         - passed through to phase_by_cond
    dp, expName - data directory and the name of the data-list file inside it
    '''
    dataList = hf.np_smart_load(str(dp + expName))
    fileName = dataList['unitName'][cell_num - 1]
    cellStruct = hf.np_smart_load(str(dp + fileName + '_sfm.npy'))
    data = cellStruct['sfm']['exp']['trial']

    # prepare the valid stim parameters by condition in case needed
    _, _, val_con_by_disp, validByStimVal, _ = hf.tabulate_responses(data)
    # gather the sf indices in case we need - this is a dictionary whose keys are the valid sf indices
    valSf = validByStimVal[2]

    # None (new default) and an empty sequence (the old `[]` default) both mean "use all valid values";
    # len() is used instead of `== []` so numpy arrays of indices are accepted too
    if cons is None or len(cons) == 0:
        cons = val_con_by_disp[disp]
    if sfs is None or len(sfs) == 0:
        sfs = list(valSf.keys())

    for c in cons:
        for s in sfs:
            print('analyzing cell %d, dispersion %d, contrast %d, sf %d\n' % (cell_num, disp, c, s))
            phase_by_cond(cell_num, data, disp, c, s, dir=dir)
def descr_loss(params, data, family, contrast, loss_type = 1):
    ''' Loss of the flexible-Gaussian SF tuning model for one dispersion/contrast condition.

    params    - parameter vector passed to flexible_Gauss
    data      - cell structure; trials are read from data['sfm']['exp']['trial']
    family    - dispersion index into the valid-by-dispersion masks
    contrast  - contrast index into the valid-by-contrast masks
    loss_type - 1: sum of squared errors on spike counts
                2: sum of squared errors on sqrt(spike counts)
                3: Poisson negative log-likelihood of the spike counts
    Returns the scalar loss. Raises ValueError for any other loss_type.
    '''
    trial = data['sfm']['exp']['trial']
    _, _, _, validByStimVal, _ = hfunc.tabulate_responses(data)

    # get indices for trials we want to look at
    valid_disp = validByStimVal[0]
    valid_con = validByStimVal[1]
    # family, contrast are in absolute terms (i.e. we pass over non-valid ones, so we can just index normally)
    curr_con = valid_con[contrast]
    curr_disp = valid_disp[family]
    indices = np.where(curr_con & curr_disp)

    obs_count = trial['spikeCount'][indices]
    pred_rate = flexible_Gauss(params, trial['sf'][0][indices])
    stim_dur = trial['duration'][indices]
    pred_count = pred_rate * stim_dur  # model rate -> expected count over the stimulus duration

    if loss_type == 1:
        curr_loss = np.square(pred_count - obs_count)
    elif loss_type == 2:
        curr_loss = np.square(np.sqrt(pred_count) - np.sqrt(obs_count))
    elif loss_type == 3:
        # poisson model of spiking
        poiss = poisson.pmf(obs_count, pred_count)
        if np.any(poiss == 0):
            poiss = np.maximum(poiss, 1e-6)  # anything, just so we avoid log(0)
        curr_loss = -np.log(poiss)
    else:
        raise ValueError('descr_loss: unknown loss_type %s (expected 1, 2 or 3)' % (loss_type,))

    NLL = np.sum(curr_loss)
    return NLL
hf.organize_resp(mr, expData, expInd, respsAsRate=False) for mr in modResps ] oriModResps = [org[0] for org in orgs] # only non-empty if expInd = 1 conModResps = [org[1] for org in orgs] # only non-empty if expInd = 1 sfmixModResps = [org[2] for org in orgs] allSfMixs = [org[3] for org in orgs] modLows = [np.nanmin(resp, axis=3) for resp in allSfMixs] modHighs = [np.nanmax(resp, axis=3) for resp in allSfMixs] modAvgs = [np.nanmean(resp, axis=3) for resp in allSfMixs] modSponRates = [fit[6] for fit in modFits] # more tabulation - stim vals, organize measured responses _, stimVals, val_con_by_disp, validByStimVal, _ = hf.tabulate_responses( expData, expInd) if rvcAdj == 1: rvcFlag = '' rvcFits = hf.get_rvc_fits(data_loc, expInd, cellNum, rvcName=rvcBase, rvcMod=rvcMod) asRates = True else: rvcFlag = '_f0' rvcFits = hf.get_rvc_fits(data_loc, expInd, cellNum, rvcName='None') asRates = False # TODO: maybe should make asRates=True here, too, right? Since get_adjusted_spikerate always gives us a rate? Should check... # rvcMod=-1 tells the function call to treat rvcName as the fits, already (we loaded above!) spikes_rate = hf.get_adjusted_spikerate(expData['sfm']['exp']['trial'],
def phase_advance_fit(cell_num, data_loc=dataPath, phAdvName=phAdvName, to_save=1, disp=0, dir=-1):
    ''' Given the FFT-derived response amplitude and phase, determine the response phase relative
    to the stimulus by taking into account the stimulus phase.
    Then, make a simple linear model fit (line + constant offset) of the response phase as a function
    of response amplitude.

    SAVES loss/optimized parameters/and phase advance (if default "to_save" value is kept)
    to data_loc + hf.fit_name(phAdvName, dir), under key cell_num-1
    ('loss', 'params', 'phAdv').
    RETURNS phAdv_model, all_opts

    Do ONLY for single gratings
    '''
    dataList = hf.np_smart_load(data_loc + 'dataList.npy')
    cellStruct = hf.np_smart_load(data_loc + dataList['unitName'][cell_num - 1] + '_sfm.npy')
    data = cellStruct['sfm']['exp']['trial']

    phAdvName = hf.fit_name(phAdvName, dir)

    # first, get the set of stimulus values:
    # NOTE(review): expInd is not a parameter - it is read from module scope; confirm that is intended
    _, stimVals, valConByDisp, _, _ = hf.tabulate_responses(data, expInd=expInd)
    allCons = stimVals[1]

    # for all con/sf values for this dispersion, compute the mean amplitude/phase per condition
    allAmp, allPhi, allTf, _, _ = hf.get_all_fft(data, disp, dir=dir)

    # now, compute the phase advance
    conInds = valConByDisp[disp]
    conVals = allCons[conInds]
    # NOTE(review): hf.phase_advance receives the flat list of valid contrasts (conVals), not one copy
    # per condition (nConds = nCons * nSfs); confirm that is what it expects
    phAdv_model, all_opts, all_phAdv, all_loss = hf.phase_advance(allAmp, allPhi, conVals, allTf)

    # load the saved fits only now (after the fit), in case some other run has saved/made changes
    if os.path.isfile(data_loc + phAdvName):
        print('reloading phAdvFits...')
        phFits = hf.np_smart_load(data_loc + phAdvName)
    else:
        phFits = dict()

    cellFits = phFits.setdefault(cell_num - 1, dict())
    cellFits['loss'] = all_loss
    cellFits['params'] = all_opts
    cellFits['phAdv'] = all_phAdv

    if to_save:
        np.save(data_loc + phAdvName, phFits)
        print('saving phase advance fit for cell ' + str(cell_num))

    return phAdv_model, all_opts
def fit_descr_DoG(cell_num, data_loc=dataPath, n_repeats=1000, loss_type=3, DoGmodel=1, disp=0, rvcName=rvcName, dir=-1, gain_reg=0, fLname=dogName):
    ''' Fit a difference-of-Gaussians (DoG) SF tuning model to the phase-adjusted responses of one cell.

    Responses and their SEMs are read from the saved RVC fits ('adjMeans' / 'adjSem' for the requested
    dispersion). For each valid contrast at that dispersion, n_repeats optimizations are run from random
    initial parameters (alternating L-BFGS-B and TNC) and the lowest-loss result is kept.

    loss_type - 1: '_poiss', 2: '_sqrt', 3: '_sach', 4: '_varExpl' (suffix used in the output name)
    DoGmodel  - 1: '_sach' (gain/radius for centre and surround), 2: '_tony' (gain/freq + surround fractions)
    disp      - the one dispersion fit by this call
    gain_reg  - passed to DoG_loss; also stored as 'gainRegFactor'
    fLname    - base name of the output file; loss/model suffixes and '.npy' are appended

    Saves a dict keyed by cell_num-1 with 'NLL', 'params', 'varExpl', 'prefSf', 'charFreq' and
    'gainRegFactor' to data_loc + fLname + <loss suffix> + <model suffix> + '.npy'.
    Raises ValueError for an unknown loss_type or DoGmodel.
    '''
    nParam = 4

    # load cell information
    dataList = hf.np_smart_load(data_loc + 'dataList.npy')
    assert dataList != [], "data file not found!"

    loss_strs = {1: '_poiss', 2: '_sqrt', 3: '_sach', 4: '_varExpl'}
    mod_strs = {1: '_sach', 2: '_tony'}
    if loss_type not in loss_strs:
        raise ValueError('fit_descr_DoG: unknown loss_type %s (expected one of %s)' % (loss_type, sorted(loss_strs)))
    if DoGmodel not in mod_strs:
        raise ValueError('fit_descr_DoG: unknown DoGmodel %s (expected one of %s)' % (DoGmodel, sorted(mod_strs)))

    fLname = str(data_loc + fLname + loss_strs[loss_type] + mod_strs[DoGmodel] + '.npy')
    if os.path.isfile(fLname):
        descrFits = hf.np_smart_load(fLname)
    else:
        descrFits = dict()

    cellStruct = hf.np_smart_load(data_loc + dataList['unitName'][cell_num - 1] + '_sfm.npy')
    data = cellStruct['sfm']['exp']['trial']

    # phase-adjusted responses and their SEMs come from the saved RVC fits
    rvcNameFinal = hf.phase_fit_name(rvcName, dir)
    rvcFits = hf.np_smart_load(data_loc + rvcNameFinal)
    cellRvc = rvcFits[cell_num - 1][disp]
    adjResps = cellRvc['adjMeans']
    adjSem = cellRvc['adjSem']
    # trial-by-trial responses may be absent from a fit file; default to None so the
    # mixture branch below never touches an undefined name
    adjByTr = cellRvc['adjByTr'] if 'adjByTr' in cellRvc else None
    if disp == 1:
        # mixture responses are stored per component, so sum across components
        adjResps = [np.sum(x, 1) if x else [] for x in adjResps]
        if adjByTr:
            adjByTr = [np.sum(x, 1) if x else [] for x in adjByTr]
    adjResps = np.array(adjResps)  # indexing multiple SFs will work only if we convert to numpy array first
    adjSem = np.array([np.array(x) for x in adjSem])  # make each inner list an array, and the whole thing an array

    print('Doing the work, now')

    # first, get the set of stimulus values:
    # NOTE(review): expInd is read from module scope (LGN is expInd=3); confirm that is intended
    _, stimVals, valConByDisp, _, _ = hf.tabulate_responses(data, expInd=expInd)
    all_disps = stimVals[0]
    all_cons = stimVals[1]
    all_sfs = stimVals[2]

    nDisps = len(all_disps)
    nCons = len(all_cons)

    if cell_num - 1 in descrFits:
        bestNLL = descrFits[cell_num - 1]['NLL']
        currParams = descrFits[cell_num - 1]['params']
        varExpl = descrFits[cell_num - 1]['varExpl']
        prefSf = descrFits[cell_num - 1]['prefSf']
        charFreq = descrFits[cell_num - 1]['charFreq']
    else:  # set values to NaN...
        bestNLL = np.ones((nDisps, nCons)) * np.nan
        currParams = np.ones((nDisps, nCons, nParam)) * np.nan
        varExpl = np.ones((nDisps, nCons)) * np.nan
        prefSf = np.ones((nDisps, nCons)) * np.nan
        charFreq = np.ones((nDisps, nCons)) * np.nan

    # set bounds
    if DoGmodel == 1:
        bound_gainCent = (1e-3, None)
        bound_radiusCent = (1e-3, None)
        bound_gainSurr = (1e-3, None)
        bound_radiusSurr = (1e-3, None)
        allBounds = (bound_gainCent, bound_radiusCent, bound_gainSurr, bound_radiusSurr)
    else:  # DoGmodel == 2 (validated above)
        bound_gainCent = (1e-3, None)
        bound_gainFracSurr = (1e-2, 1)
        bound_freqCent = (1e-3, None)
        bound_freqFracSurr = (1e-2, 1)
        allBounds = (bound_gainCent, bound_freqCent, bound_gainFracSurr, bound_freqFracSurr)

    # only the requested dispersion is fit (single gratings and mixtures are fit in separate calls)
    for con in range(nCons):
        if con not in valConByDisp[disp]:
            continue

        valSfInds = hf.get_valid_sfs(data, disp, con, expInd)
        valSfVals = all_sfs[valSfInds]

        print('.')
        # adjResponses (f1) in the rvcFits are separate by sf, values within contrast - so to get all responses for a given SF,
        # access all sfs and get the specific contrast response
        respConInd = np.where(np.asarray(valConByDisp[disp]) == con)[0]
        resps = flatten([x[respConInd] for x in adjResps[valSfInds]])
        resps_sem = [x[respConInd] for x in adjSem[valSfInds]]
        if isinstance(resps_sem[0], np.ndarray):  # i.e. if it's still array of arrays...
            resps_sem = flatten(resps_sem)

        maxResp = np.max(resps)
        # resps holds one value per *valid* SF, so its argmax indexes valSfVals, not all_sfs
        freqAtMaxResp = valSfVals[np.argmax(resps)]

        for n_try in range(n_repeats):
            # pick initial params
            if DoGmodel == 1:
                init_gainCent = hf.random_in_range((maxResp, 5 * maxResp))[0]
                init_radiusCent = hf.random_in_range((0.05, 2))[0]
                init_gainSurr = init_gainCent * hf.random_in_range((0.1, 0.8))[0]
                init_radiusSurr = hf.random_in_range((0.5, 4))[0]
                init_params = [init_gainCent, init_radiusCent, init_gainSurr, init_radiusSurr]
            else:  # DoGmodel == 2
                init_gainCent = maxResp * hf.random_in_range((0.9, 1.2))[0]
                # don't pick all_sfs[0] -- that's zero (we're avoiding that)
                init_freqCent = np.maximum(all_sfs[2], freqAtMaxResp * hf.random_in_range((1.2, 1.5))[0])
                init_gainFracSurr = hf.random_in_range((0.7, 1))[0]
                init_freqFracSurr = hf.random_in_range((.25, .35))[0]
                init_params = [init_gainCent, init_freqCent, init_gainFracSurr, init_freqFracSurr]

            # choose optimization method (alternate between the two bounded optimizers)
            methodStr = 'L-BFGS-B' if np.mod(n_try, 2) == 0 else 'TNC'

            obj = lambda params: DoG_loss(params, resps, valSfVals, resps_std=resps_sem, loss_type=loss_type, DoGmodel=DoGmodel, dir=dir, gain_reg=gain_reg)
            wax = opt.minimize(obj, init_params, method=methodStr, bounds=allBounds)

            # compare
            NLL = wax['fun']
            params = wax['x']

            if np.isnan(bestNLL[disp, con]) or NLL < bestNLL[disp, con]:
                bestNLL[disp, con] = NLL
                currParams[disp, con, :] = params
                varExpl[disp, con] = hf.var_explained(resps, params, valSfVals)
                prefSf[disp, con] = hf.dog_prefSf(params, DoGmodel, valSfVals)
                charFreq[disp, con] = hf.dog_charFreq(params, DoGmodel)

    # update stuff - load again in case some other run has saved/made changes
    if os.path.isfile(fLname):
        print('reloading descrFits...')
        descrFits = hf.np_smart_load(fLname)
    cellFits = descrFits.setdefault(cell_num - 1, dict())
    cellFits['NLL'] = bestNLL
    cellFits['params'] = currParams
    cellFits['varExpl'] = varExpl
    cellFits['prefSf'] = prefSf
    cellFits['charFreq'] = charFreq
    cellFits['gainRegFactor'] = gain_reg

    np.save(fLname, descrFits)
    print('saving for cell ' + str(cell_num))
def rvc_adjusted_fit(cell_num, data_loc=dataPath, rvcName=rvcName, to_save=1, disp=0, dir=-1):
    ''' Piggy-backing off of phase_advance_fit, prepare to project the responses onto the proper phase
    to get the correct amplitude. Then, with the corrected response amplitudes, fit the RVC model.

    disp    - 0 (single gratings) or 1 (mixtures; component responses are summed before fitting)
    to_save - if truthy, write the updated fits to data_loc + hf.fit_name(rvcName, dir)
    Stores, under [cell_num-1][disp]: 'loss', 'params', 'conGain', 'adjMeans', 'adjByTr', 'adjSem', 'adjSemComp'.
    Returns rvc_model, all_opts, all_conGains, adjMeans.
    Raises ValueError if disp is not 0 or 1.
    '''
    if disp not in (0, 1):
        raise ValueError('rvc_adjusted_fit: disp must be 0 or 1, got %s' % (disp,))

    dataList = hf.np_smart_load(data_loc + 'dataList.npy')
    cellStruct = hf.np_smart_load(data_loc + dataList['unitName'][cell_num - 1] + '_sfm.npy')
    data = cellStruct['sfm']['exp']['trial']

    rvcNameFinal = hf.fit_name(rvcName, dir)

    # first, get the set of stimulus values:
    # NOTE(review): expInd is read from module scope (not a parameter); confirm that is intended
    _, stimVals, valConByDisp, _, _ = hf.tabulate_responses(data, expInd=expInd)
    allCons = stimVals[1]
    allSfs = stimVals[2]
    valCons = allCons[valConByDisp[disp]]

    # calling phase_advance fit, use the phAdv_model and optimized parameters to compute the true response amplitude
    # given the measured/observed amplitude and phase of the response
    # NOTE: We always call phase_advance_fit with disp=0 (default), since we don't make a fit
    # for the mixture stimuli - instead, we use the fits made on single gratings to project the
    # individual-component-in-mixture responses.
    # data_loc is forwarded so the phase-advance fit reads from the same data directory as this call.
    phAdv_model, phAdv_opts = phase_advance_fit(cell_num, data_loc=data_loc, dir=dir, to_save=0)  # don't save

    allAmp, allPhi, _, allCompCon, allCompSf = hf.get_all_fft(data, disp, dir=dir, all_trials=1)
    # get just the mean amp/phi and put into convenient lists
    allAmpMeans = [[x[0] for x in sf] for sf in allAmp]  # mean is in the first element; do that for each [mean, std] pair in each list (split by sf)
    allAmpTrials = [[x[2] for x in sf] for sf in allAmp]  # trial-by-trial is third element

    allPhiMeans = [[x[0] for x in sf] for sf in allPhi]  # mean is in the first element; do that for each [mean, var] pair in each list (split by sf)
    allPhiTrials = [[x[2] for x in sf] for sf in allPhi]  # trial-by-trial is third element

    adjMeans = hf.project_resp(allAmpMeans, allPhiMeans, phAdv_model, phAdv_opts, disp, allCompSf, allSfs)
    adjByTrial = hf.project_resp(allAmpTrials, allPhiTrials, phAdv_model, phAdv_opts, disp, allCompSf, allSfs)
    consRepeat = [valCons] * len(adjMeans)

    if disp == 1:
        # then we need to sum component responses and get overall std measure (we'll fit to sum, not indiv. comp responses!)
        adjSumResp = [np.sum(x, 1) if x else [] for x in adjMeans]
        adjSemTr = [[sem(np.sum(hf.switch_inner_outer(x), 1)) for x in y] for y in adjByTrial]
        adjSemCompTr = [[sem(hf.switch_inner_outer(x)) for x in y] for y in adjByTrial]
        rvc_model, all_opts, all_conGains, all_loss = hf.rvc_fit(adjSumResp, consRepeat, adjSemTr)
    else:  # disp == 0 (validated above)
        adjSemTr = [[sem(x) for x in y] for y in adjByTrial]
        adjSemCompTr = adjSemTr  # for single gratings, there is only one component!
        rvc_model, all_opts, all_conGains, all_loss = hf.rvc_fit(adjMeans, consRepeat, adjSemTr)

    # load the saved fits only now (after fitting), in case some other run has saved/made changes
    if os.path.isfile(data_loc + rvcNameFinal):
        print('reloading rvcFits...')
        rvcFits = hf.np_smart_load(data_loc + rvcNameFinal)
    else:
        rvcFits = dict()

    dispFits = rvcFits.setdefault(cell_num - 1, dict()).setdefault(disp, dict())
    dispFits['loss'] = all_loss
    dispFits['params'] = all_opts
    dispFits['conGain'] = all_conGains
    dispFits['adjMeans'] = adjMeans
    dispFits['adjByTr'] = adjByTrial
    dispFits['adjSem'] = adjSemTr
    dispFits['adjSemComp'] = adjSemCompTr

    if to_save:
        np.save(data_loc + rvcNameFinal, rvcFits)
        print('saving rvc fit for cell ' + str(cell_num))

    return rvc_model, all_opts, all_conGains, adjMeans
def plot_phase_advance(which_cell, disp, sv_loc=save_loc, dir=-1, dp=dataPath, expName=expName, phAdvStr=phAdvName, rvcStr=rvcName, date_suffix=''):
    ''' RVC, resp-X-phase, phase advance model split by SF within each cell/dispersion condition
    1. response-versus-contrast; shows original and adjusted response
    2. polar plot of response amplitude and phase with phase advance model fit
    3. response amplitude (x) vs. phase (y) with the linear phase advance model fit

    One figure per SF; all figures are written into a single PDF at
    sv_loc + 'phasePlots_<date_suffix>/summary/cell_XXX_d<disp>_phaseAdv.pdf'.
    NOTE(review): phAdv_set_ylim and descrFit_f0 are read from module scope - confirm they are defined there.
    '''
    # basics
    dataList = hf.np_smart_load(str(dp + expName))
    cellName = dataList['unitName'][which_cell - 1]
    expInd = hf.get_exp_ind(dp, cellName)[0]
    cellStruct = hf.np_smart_load(str(dp + cellName + '_sfm.npy'))

    rvcFits = hf.np_smart_load(str(dp + hf.phase_fit_name(rvcStr, dir)))
    rvcFits = rvcFits[which_cell - 1]
    rvc_model = hf.get_rvc_model()

    phAdvFits = hf.np_smart_load(str(dp + hf.phase_fit_name(phAdvStr, dir)))
    phAdvFits = phAdvFits[which_cell - 1]
    phAdv_model = hf.get_phAdv_model()

    save_base = sv_loc + 'phasePlots_%s/' % date_suffix

    # gather/compute everything we need
    data = cellStruct['sfm']['exp']['trial']
    _, stimVals, val_con_by_disp, validByStimVal, _ = hf.tabulate_responses(data, expInd)

    valDisp = validByStimVal[0]
    valCon = validByStimVal[1]
    valSf = validByStimVal[2]

    allDisps = stimVals[0]
    allCons = stimVals[1]
    allSfs = stimVals[2]

    con_inds = val_con_by_disp[disp]

    # loop-invariant quantities: stimulus duration and the cell's f1/f0 ratio (used in every title)
    stimDur = hf.get_exp_params(expInd).stimDur
    f1f0_ratio = hf.compute_f1f0(data, which_cell, expInd, dp, descrFitName_f0=descrFit_f0)[0]
    textKw = dict(fontsize=12, horizontalalignment='center', verticalalignment='center')

    # now get ready to plot
    fPhaseAdv = []

    # we will summarize for all spatial frequencies for a given cell!
    # NOTE(review): assumes valSf has an entry for every sf index in range(len(allSfs))
    for sf in range(len(allSfs)):
        # first, get the responses and phases that we need:
        amps = []
        phis = []
        for i in con_inds:
            val_trials = np.where(valDisp[disp] & valCon[i] & valSf[sf])

            # get the phase of the response relative to the stimulus (ph_rel_stim)
            ph_rel_stim, _, _, all_tf = hf.get_true_phase(data, val_trials, expInd, dir=dir)
            phis.append(ph_rel_stim)
            # get the relevant amplitudes (i.e. the amplitudes at the stimulus TF)
            psth_val, _ = hf.make_psth(data['spikeTimes'][val_trials], stimDur=stimDur)
            _, rel_amp, _ = hf.spike_fft(psth_val, all_tf, stimDur=stimDur)
            amps.append(rel_amp)

        r, th, _, _ = hf.polar_vec_mean(amps, phis)  # mean amp/phase (outputs 1/2); std/var for amp/phase (outputs 3/4)

        # get the models/fits that we need:
        ## phase advance
        opt_params_phAdv = phAdvFits['params'][sf]
        ph_adv = phAdvFits['phAdv'][sf]
        ## rvc
        opt_params_rvc = rvcFits[disp]['params'][sf]
        con_gain = rvcFits[disp]['conGain'][sf]
        adj_means = rvcFits[disp]['adjMeans'][sf]
        if disp == 1:  # then sum adj_means (saved by component)
            adj_means = [np.sum(x, 1) if x else [] for x in adj_means]
        # (Above) remember that we have to project the amp/phase vector onto the "correct" phase for estimate of noiseless response

        ## now get ready to plot!
        f, ax = plt.subplots(2, 2, figsize=(20, 10))
        fPhaseAdv.append(f)

        n_conds = len(r)
        colors = cm.viridis(np.linspace(0, 0.95, n_conds))

        #####
        ## 1. now for plotting: first, response amplitude (with linear contrast)
        #####
        plot_cons = np.linspace(0, 1, 100)
        mod_fit = rvc_model(opt_params_rvc[0], opt_params_rvc[1], opt_params_rvc[2], plot_cons)

        ax = plt.subplot(2, 2, 1)
        ax.scatter(allCons[con_inds], adj_means, s=100, color=colors, label='ph. corr')
        ax.plot(allCons[con_inds], r, linestyle='None', marker='o', markeredgecolor='k', markerfacecolor='None', alpha=0.5, label='vec. mean')
        ax.plot(plot_cons, mod_fit, linestyle='--', color='k', label='rvc fit')
        ax.set_xlabel('contrast')
        ax.set_ylabel('response (f1)')
        ax.set_title('response versus contrast')
        ax.legend(loc='upper left')

        # also summarize the model fit on this plot
        ymax = np.maximum(np.max(r), np.max(mod_fit))
        plt.text(0.8, 0.30 * ymax, 'b: %.2f' % (opt_params_rvc[0]), **textKw)
        plt.text(0.8, 0.20 * ymax, 'slope:%.2f' % (opt_params_rvc[1]), **textKw)
        plt.text(0.8, 0.10 * ymax, 'c0: %.2f' % (opt_params_rvc[2]), **textKw)
        plt.text(0.8, 0.0 * ymax, 'con gain: %.2f' % (con_gain), **textKw)

        #####
        ## 3. then the fit/plot of phase as a function of amplitude
        #####
        plot_amps = np.linspace(0, np.max(r), 100)
        mod_fit = phAdv_model(opt_params_phAdv[0], opt_params_phAdv[1], plot_amps)

        ax = plt.subplot(2, 1, 2)
        ax.scatter(r, th, s=100, color=colors, clip_on=False, label='vec. mean')
        ax.plot(plot_amps, mod_fit, linestyle='--', color='k', clip_on=False, label='phAdv model')
        ax.set_xlabel('response amplitude')
        if phAdv_set_ylim:
            ax.set_ylim([0, 360])
        ax.set_ylabel('response phase')
        ax.set_title('phase advance with amplitude')
        ax.legend(loc='upper left')

        ## and again, summarize the model fit on the plot
        xmax = np.maximum(np.max(r), np.max(plot_amps))
        ymin = np.minimum(np.min(th), np.min(mod_fit))
        ymax = np.maximum(np.max(th), np.max(mod_fit))
        yrange = ymax - ymin
        if phAdv_set_ylim:
            if mod_fit[-1] > 260:  # then start from ymin and go down
                start, sign = mod_fit[-1] - 30, -1
            else:
                start, sign = mod_fit[-1] + 30, 1
            plt.text(0.9 * xmax, start + 1 * 30 * sign, 'phi0: %.2f' % (opt_params_phAdv[0]), **textKw)
            plt.text(0.9 * xmax, start + 2 * 30 * sign, 'slope:%.2f' % (opt_params_phAdv[1]), **textKw)
            plt.text(0.9 * xmax, start + 3 * 30 * sign, 'phase advance: %.2f ms' % (ph_adv), **textKw)
        else:
            plt.text(0.8 * xmax, ymin + 0.25 * yrange, 'phi0: %.2f' % (opt_params_phAdv[0]), **textKw)
            plt.text(0.8 * xmax, ymin + 0.15 * yrange, 'slope:%.2f' % (opt_params_phAdv[1]), **textKw)
            plt.text(0.8 * xmax, ymin + 0.05 * yrange, 'phase advance: %.2f ms' % (ph_adv), **textKw)

        #####
        ## 2. now the polar plot of resp/phase together
        #####
        ax = plt.subplot(2, 2, 2, projection='polar')
        th_center = np.rad2deg(np.radians(-90) + np.radians(th[np.argmax(r)]))  # "anchor" to the phase at the highest amplitude response
        data_centered = np.mod(th - th_center, 360)
        model_centered = np.mod(mod_fit - th_center, 360)
        ax.scatter(np.deg2rad(data_centered), r, s=50, color=colors)
        ax.plot(np.deg2rad(model_centered), plot_amps, linestyle='--', color='k')
        ax.set_ylim(0, 1.25 * np.max(r))
        ax.set_title('phase advance')

        # overall title
        f.subplots_adjust(wspace=0.2, hspace=0.25)
        titleFmt = '%s (%.2f) #%d: disp %d, sf %.2f cpd'
        try:
            f.suptitle(titleFmt % (dataList['unitType'][which_cell - 1], f1f0_ratio, which_cell, allDisps[disp], allSfs[sf]))
        except (KeyError, IndexError):  # no usable 'unitType' entry - label with the recording area instead
            f.suptitle(titleFmt % (dataList['unitArea'][which_cell - 1], f1f0_ratio, which_cell, allDisps[disp], allSfs[sf]))

    saveName = "/cell_%03d_d%d_phaseAdv.pdf" % (which_cell, disp)
    summary_loc = save_base + 'summary/'
    full_save = os.path.dirname(str(summary_loc))
    if not os.path.exists(full_save):
        os.makedirs(full_save)
    # context manager guarantees the PDF is closed even if a savefig call fails
    with pltSave.PdfPages(full_save + saveName) as pdfSv:
        for fig in fPhaseAdv:
            pdfSv.savefig(fig)
            plt.close(fig)
gs_std = modFit_wg[9]; # now organize the responses orgs = [hf.organize_resp(mr, expData, expInd) for mr in modResps]; #orgs = [hf.organize_modResp(mr, expData) for mr in modResps]; sfmixModResps = [org[2] for org in orgs]; allSfMixs = [org[3] for org in orgs]; # now organize the measured responses in the same way _, _, sfmixExpResp, allSfMixExp = hf.organize_resp(expData['sfm']['exp']['trial']['spikeCount'], expData, expInd); modLows = [np.nanmin(resp, axis=3) for resp in allSfMixs]; modHighs = [np.nanmax(resp, axis=3) for resp in allSfMixs]; modAvgs = [np.nanmean(resp, axis=3) for resp in allSfMixs]; modSponRates = [fit[6] for fit in modFits]; # more tabulation resp, stimVals, val_con_by_disp, _, _ = hf.tabulate_responses(expData, expInd, modResps[0]); respMean = resp[0]; respStd = resp[1]; blankMean, blankStd, _ = hf.blankResp(expData); all_disps = stimVals[0]; all_cons = stimVals[1]; all_sfs = stimVals[2]; nCons = len(all_cons); nSfs = len(all_sfs); nDisps = len(all_disps); # ### Plots
# #### determine contrasts, center spatial frequency, dispersions data = cellStruct['sfm']['exp']['trial'] ignore, modRespAll = mod_resp.SFMGiveBof(modParamsCurr, cellStruct, normType=norm_type, lossType=lossType, expInd=expInd) print('norm type %02d' % (norm_type)) if norm_type == 2: gs_mean = modParamsCurr[1] # guaranteed to exist after call to .SFMGiveBof, if norm_type == 2 gs_std = modParamsCurr[2] # guaranteed to exist ... resp, stimVals, val_con_by_disp, validByStimVal, modResp = hf.tabulate_responses( cellStruct, expInd, modRespAll) blankMean, blankStd, _ = hf.blankResp(cellStruct) modBlankMean = modParamsCurr[6] # late additive noise is the baseline of the model # all responses on log ordinate (y axis) should be baseline subtracted all_disps = stimVals[0] all_cons = stimVals[1] all_sfs = stimVals[2] nCons = len(all_cons) nSfs = len(all_sfs) nDisps = len(all_disps) # #### Unpack responses
# #### Load descriptive model fits, comp. model fits
# These .npy files hold pickled dicts (hence .item()); numpy >= 1.16.3 refuses to load object arrays
# unless allow_pickle=True. Only load them from trusted local data directories - unpickling can
# execute arbitrary code.
descrFits = np.load(str(dataPath + 'descrFits.npy'), encoding='latin1', allow_pickle=True).item()
descrFits = descrFits[which_cell - 1]['params'] # just get this cell

modParams = np.load(str(dataPath + fitListName), encoding='latin1', allow_pickle=True).item()
modParamsCurr = modParams[which_cell - 1]['params']

# ### Organize data
# #### determine contrasts, center spatial frequency, dispersions
data = cellStruct['sfm']['exp']['trial']

# only the second output of SFMGiveBof is needed (the model responses)
modRespAll = model_responses.SFMGiveBof(modParamsCurr, cellStruct)[1]
resp, stimVals, val_con_by_disp, validByStimVal, modResp = helper_fcns.tabulate_responses(
    cellStruct, modRespAll)
blankMean, blankStd, _ = helper_fcns.blankResp(cellStruct)
# all responses on log ordinate (y axis) should be baseline subtracted

# stimVals: [0] dispersions, [1] contrasts, [2] spatial frequencies
all_disps = stimVals[0]
all_cons = stimVals[1]
all_sfs = stimVals[2]

nCons = len(all_cons)
nSfs = len(all_sfs)
nDisps = len(all_disps)

# #### Unpack responses
respMean = resp[0]
respStd = resp[1]
def fit_descr(cell_num, data_loc, n_repeats = 4, loss_type = 1):
    ''' Fit the 5-parameter flexible-Gaussian SF tuning curve to every valid dispersion/contrast of one cell.

    cell_num  - 1-indexed cell number (row cell_num-1 of dataList['unitName'])
    data_loc  - directory holding dataList.npy, the cell's *_sfm.npy file and the descrFits output
    n_repeats - random restarts per condition (alternating L-BFGS-B and TNC)
    loss_type - 1: least squares ('_lsq'), 2: sqrt ('_sqrt'), 3: poisson ('_poiss'); passed to descr_loss
    Saves {cell_num-1: {'NLL': ..., 'params': ...}} (arrays indexed [disp, con] / [disp, con, param])
    into data_loc + 'descrFits' + <loss suffix>.npy.
    Raises ValueError for an unknown loss_type.
    '''
    nParam = 5

    loss_strs = {1: '_lsq.npy', 2: '_sqrt.npy', 3: '_poiss.npy'}
    if loss_type not in loss_strs:
        raise ValueError('fit_descr: unknown loss_type %s (expected 1, 2 or 3)' % (loss_type,))
    fit_path = data_loc + 'descrFits' + loss_strs[loss_type]

    # load cell information
    dataList = hfunc.np_smart_load(data_loc + 'dataList.npy')
    if os.path.isfile(fit_path):
        descrFits = hfunc.np_smart_load(fit_path)
    else:
        descrFits = dict()
    data = hfunc.np_smart_load(data_loc + dataList['unitName'][cell_num-1] + '_sfm.npy')

    print('Doing the work, now')

    to_unpack = hfunc.tabulate_responses(data)
    respMean, _, _, _ = to_unpack[0]
    all_disps, all_cons, all_sfs = to_unpack[1]
    val_con_by_disp = to_unpack[2]

    nDisps = len(all_disps)
    nCons = len(all_cons)

    if cell_num-1 in descrFits:
        bestNLL = descrFits[cell_num-1]['NLL']
        currParams = descrFits[cell_num-1]['params']
    else:  # set values to NaN...
        bestNLL = np.ones((nDisps, nCons)) * np.nan
        currParams = np.ones((nDisps, nCons, nParam)) * np.nan

    # initial-parameter range for the baseline depends only on the blank response, not on the condition
    base_rate = hfunc.blankResp(data)[0]
    if base_rate <= 3:
        range_baseline = (0, 3)
    else:
        range_baseline = (0.5 * base_rate, 1.5 * base_rate)

    # condition-independent bandwidth settings
    log_bw_lo = 0.75  # 0.75 octave bandwidth...
    log_bw_hi = 2  # 2 octave bandwidth...
    min_bw = 1/4
    max_bw = 10  # ranges in octave bandwidth
    fwhm_to_sig = 2*np.sqrt(2*np.log(2))  # Gaussian at half-height
    bound_mu = (0.01, 10)
    bound_sig = (np.maximum(0.1, min_bw/fwhm_to_sig), max_bw/fwhm_to_sig)

    for family in range(nDisps):
        for con in range(nCons):

            if con not in val_con_by_disp[family]:
                continue

            print('.')
            # set initial parameters - a range from which we will pick!
            valid_sf_inds = ~np.isnan(respMean[family, :, con])
            max_resp = np.amax(respMean[family, valid_sf_inds, con])
            range_amp = (0.5 * max_resp, 1.5 * max_resp)

            theSfCents = all_sfs[valid_sf_inds]
            nSfCents = len(theSfCents)

            max_sf_index = np.argmax(respMean[family, valid_sf_inds, con])  # what sf index gives peak response?
            mu_init = theSfCents[max_sf_index]

            # indices are clamped so that conditions with only a few valid SFs don't index out of range
            if max_sf_index == 0:  # i.e. smallest SF center gives max response...
                range_mu = (mu_init/2, theSfCents[min(max_sf_index + 3, nSfCents - 1)])
            elif max_sf_index+1 == nSfCents:  # i.e. highest SF center is max
                range_mu = (theSfCents[max(max_sf_index - 2, 0)], mu_init)
            else:
                range_mu = (theSfCents[max_sf_index-1], theSfCents[max_sf_index+1])  # go +-1 indices from center

            denom_lo = hfunc.bw_log_to_lin(log_bw_lo, mu_init)[0]  # get linear bandwidth
            denom_hi = hfunc.bw_log_to_lin(log_bw_hi, mu_init)[0]  # get lin. bw (cpd)
            range_denom = (denom_lo, denom_hi)  # don't want 0 in sigma

            # set bounds for parameters
            bound_baseline = (0, max_resp)
            bound_range = (0, 1.5*max_resp)
            all_bounds = (bound_baseline, bound_range, bound_mu, bound_sig, bound_sig)

            for n_try in range(n_repeats):
                # pick initial params
                init_base = hfunc.random_in_range(range_baseline)
                init_amp = hfunc.random_in_range(range_amp)
                init_mu = hfunc.random_in_range(range_mu)
                init_sig_left = hfunc.random_in_range(range_denom)
                init_sig_right = hfunc.random_in_range(range_denom)
                # random_in_range may return 1-element arrays; flatten so x0 is a 1-D vector of nParam values
                init_params = np.asarray([init_base, init_amp, init_mu, init_sig_left, init_sig_right], dtype=float).ravel()

                # choose optimization method (alternate between the two bounded optimizers)
                methodStr = 'L-BFGS-B' if np.mod(n_try, 2) == 0 else 'TNC'

                obj = lambda params: descr_loss(params, data, family, con, loss_type)
                wax = opt.minimize(obj, init_params, method=methodStr, bounds=all_bounds)

                # compare
                NLL = wax['fun']
                params = wax['x']

                if np.isnan(bestNLL[family, con]) or NLL < bestNLL[family, con] or invalid(currParams[family, con, :], all_bounds):
                    bestNLL[family, con] = NLL
                    currParams[family, con, :] = params

    # update stuff - load again in case some other run has saved/made changes
    if os.path.isfile(fit_path):
        print('reloading descrFits...')
        descrFits = hfunc.np_smart_load(fit_path)
    cellFits = descrFits.setdefault(cell_num-1, dict())
    cellFits['NLL'] = bestNLL
    cellFits['params'] = currParams

    np.save(fit_path, descrFits)
    print('saving for cell ' + str(cell_num))
else: # otherwise, if it's complex, just get F0 respMeasure = 0 spikes = hf.get_spikes(expData, get_f0=1, rvcFits=None, expInd=expInd) rates = False # get_spikes without rvcFits is directly from spikeCount, which is counts, not rates! baseline = hf.blankResp(expData, expInd)[0] # we'll plot the spontaneous rate # why mult by stimDur? well, spikes are not rates but baseline is, so we convert baseline to count (i.e. not rate, too) spikes = spikes - baseline * hf.get_exp_params(expInd).stimDur #print('###\nGetting spikes (data): rates? %d\n###' % rates); _, _, _, respAll = hf.organize_resp(spikes, expData, expInd, respsAsRate=rates) # only using respAll to get variance measures resps_data, stimVals, val_con_by_disp, _, _ = hf.tabulate_responses( expData, expInd, overwriteSpikes=spikes, respsAsRates=rates, modsAsRate=rates) if fitList is None: resps = resps_data # otherwise, we'll still keep resps_data for reference elif fitList is not None: # OVERWRITE the data with the model spikes! if use_mod_resp == 1: curr_fit = fitList[which_cell - 1]['params'] modResp = mod_resp.SFMGiveBof(curr_fit, S, normType=fitType, lossType=lossType, expInd=expInd, cellNum=which_cell,
def _apply_superposition_plot_style():
    """Apply the style sheet and rcParams used for the superposition figures.

    The style sheet is loaded from the GitHub URL below.
    """
    plt.style.use('https://raw.githubusercontent.com/paul-levy/SF_diversity/master/paul_plt_style.mplstyle')
    from matplotlib import rcParams
    rcParams['font.size'] = 20
    rcParams['pdf.fonttype'] = 42  # should be 42, but there are kerning issues
    rcParams['ps.fonttype'] = 42  # should be 42, but there are kerning issues
    rcParams['lines.linewidth'] = 2.5
    rcParams['axes.linewidth'] = 1.5
    rcParams['lines.markersize'] = 8  # this is in style sheet, just being explicit
    rcParams['lines.markeredgewidth'] = 0  # no edge, since weird things happen then

    rcParams['xtick.major.size'] = 15
    rcParams['xtick.minor.size'] = 5
    rcParams['ytick.major.size'] = 15
    rcParams['ytick.minor.size'] = 0  # no minor ticks

    rcParams['xtick.major.width'] = 2
    rcParams['xtick.minor.width'] = 2
    rcParams['ytick.major.width'] = 2
    rcParams['ytick.minor.width'] = 0

    rcParams['font.style'] = 'oblique'
    rcParams['font.size'] = 20


def _deriv_eval_inds(sf_grid, query_sfs):
    """Map each query SF to the nearest index of a finite-difference derivative on sf_grid.

    np.diff over a curve sampled on sf_grid yields len(sf_grid)-1 values, which
    are plotted against sf_grid[1:]. Matching against sf_grid[1:] therefore keeps
    every returned index in range; matching against the full grid could return
    len(sf_grid)-1, which is out of bounds for the derivative.
    """
    grid = np.asarray(sf_grid)[1:]
    return [int(np.argmin(np.square(grid - x))) for x in query_sfs]


def plot_save_superposition(which_cell, expDir, use_mod_resp=0, fitType=2, excType=1, useHPCfit=1, conType=None, lgnFrontEnd=None, force_full=1, f1_expCutoff=2, to_save=1):
    """Plot (and optionally save) the superposition analysis for one cell.

    Compares each mixture response (dispersion > 0) with the sum of its
    component responses, fits a Naka-Rushton curve to that relationship, and
    summarizes the deviations by dispersion and by spatial frequency.

    Parameters
    ----------
    which_cell : int
        1-indexed cell number (used as dataList['unitName'][which_cell-1]).
    expDir : str
        Experiment sub-directory, e.g. 'LGN/' or 'altExp/'. Selects the fit
        file names, the RVC model, and whether F1 responses are vector-corrected.
    use_mod_resp : int
        0: measured data; 1: replace responses with mod_resp.SFMGiveBof output;
        2: replace responses with the pytorch model (mrpt.sfNormMod).
    fitType : int
        Normalization type of the model fits (also used in output file names).
    excType : int
        Excitatory filter type (1 or 2), used to pick the model fit list.
    useHPCfit : int
        If truthy, use the 'HPC' variant of the fit file names.
    conType, lgnFrontEnd :
        Passed through to the fit-list name and the pytorch model.
    force_full : int
        Passed to hf.get_datalist.
    f1_expCutoff : int
        F1 responses are vector-corrected (hf.adjust_f1_byTrial) only when
        expInd > f1_expCutoff.
    to_save : int
        If truthy, merge the results into a superposition_analysis_*.npy file
        in the data directory.

    Returns
    -------
    dict
        Per-cell summary measures (suppression ratios, c50, error variances,
        correlations with the tuning-curve derivative, ...).

    Side effects: creates the figure directory if needed and writes one PDF per
    call. If to_save is set, it also sleeps up to 5 s and then
    reads/modifies/writes the shared .npy results file.
    """
    if use_mod_resp == 2:
        rvcAdj = -1  # this means vec corrected F1, not phase adjustment F1...
        _applyLGNtoNorm = 0  # don't apply the LGN front-end to the gain control weights
        recenter_norm = 1
        newMethod = 1  # yes, use the "new" method for mrpt (not that new anymore, as of 21.03)
        lossType = 1  # sqrt
        _sigmoidSigma = 5

    basePath = os.getcwd() + '/'
    loc_str = 'HPC' if ('pl1465' in basePath or useHPCfit) else ''

    rvcName = 'rvcFits%s_220531' % loc_str if expDir == 'LGN/' else 'rvcFits%s_220609' % loc_str
    rvcFits = None  # pre-define this as None; will be overwritten if available/needed
    if expDir == 'altExp/':  # we don't adjust responses there...
        rvcName = None
    # NOTE(review): '220631' is not a calendar date; confirm it matches the file on disk
    dFits_base = 'descrFits%s_220609' % loc_str if expDir == 'LGN/' else 'descrFits%s_220631' % loc_str
    if use_mod_resp == 1:
        rvcName = None  # Use NONE if getting model responses, only
        if excType == 1:
            fitBase = 'fitList_200417'
        elif excType == 2:
            fitBase = 'fitList_200507'
        lossType = 1  # sqrt
        fitList_nm = hf.fitList_name(fitBase, fitType, lossType=lossType)
    elif use_mod_resp == 2:
        rvcName = None  # Use NONE if getting model responses, only
        if excType == 1:
            fitBase = 'fitList%s_210308_dG' % loc_str
            if recenter_norm:
                fitBase = 'fitList%s_pyt_210331_dG' % loc_str
        elif excType == 2:
            fitBase = 'fitList%s_pyt_210310' % loc_str
            if recenter_norm:
                fitBase = 'fitList%s_pyt_210331' % loc_str
        fitList_nm = hf.fitList_name(fitBase, fitType, lossType=lossType, lgnType=lgnFrontEnd, lgnConType=conType, vecCorrected=-rvcAdj)
    # ^^^ EDIT rvc/descrFits/fitList names here;

    ############
    # Before any plotting, fix plotting parameters
    ############
    _apply_superposition_plot_style()

    ############
    # load everything
    ############
    dataListNm = hf.get_datalist(expDir, force_full=force_full)
    descrFits_f0 = None
    dLoss_num = 2  # see hf.descrFit_name/descrMod_name/etc for details
    if expDir == 'LGN/':
        rvcMod = 0
        dMod_num = 1
        rvcDir = 1
        vecF1 = -1
    else:
        rvcMod = 1  # i.e. Naka-rushton (1)
        dMod_num = 3  # d-dog-s
        rvcDir = None  # None if we're doing vec-corrected
        vecF1 = 0 if expDir == 'altExp/' else 1
    dFits_mod = hf.descrMod_name(dMod_num)
    descrFits_name = hf.descrFit_name(lossType=dLoss_num, descrBase=dFits_base, modelName=dFits_mod, phAdj=1 if vecF1 == -1 else None)

    ## now, let it run
    dataPath = basePath + expDir + 'structures/'
    save_loc = basePath + expDir + 'figures/'
    save_locSuper = save_loc + 'superposition_220713/'
    if use_mod_resp == 1:
        save_locSuper = save_locSuper + '%s/' % fitBase

    dataList = hf.np_smart_load(dataPath + dataListNm)
    print('Trying to load descrFits at: %s' % (dataPath + descrFits_name))
    descrFits = hf.np_smart_load(dataPath + descrFits_name)
    if use_mod_resp == 1 or use_mod_resp == 2:
        fitList = hf.np_smart_load(dataPath + fitList_nm)
    else:
        fitList = None

    # exist_ok avoids a race when several cells are processed concurrently
    os.makedirs(save_locSuper, exist_ok=True)

    def zr_rm_pair(x, z, gt):
        """Keep only entries where both x AND z exceed gt (e.g. 0, 1, 0.4, ...)."""
        keep = np.logical_and(x > gt, z > gt)
        return [x[keep], z[keep]]

    # here, we'll save measures we are going to use for analysis purposes - e.g. suppression index, c50
    curr_suppr = dict()

    ############
    ### Establish the plot, load cell-specific measures
    ############
    nRows, nCols = 6, 2
    cellName = dataList['unitName'][which_cell - 1]
    expInd = hf.get_exp_ind(dataPath, cellName)[0]
    S = hf.np_smart_load(dataPath + cellName + '_sfm.npy')
    expData = S['sfm']['exp']['trial']

    # 0th, let's load the basic tuning characterizations AND the descriptive fit
    try:
        dfit_curr = descrFits[which_cell - 1]['params'][0, -1, :]  # single grating, highest contrast
    except Exception:
        dfit_curr = None
    # - then the basics
    try:
        basic_names, basic_order = dataList['basicProgName'][which_cell - 1], dataList['basicProgOrder']
        basics = hf.get_basic_tunings(basic_names, basic_order)
    except Exception:
        try:
            # we've already put the basics in the data structure... (i.e. post-sorting 2021 data)
            basic_names = ['', '', '', '', '']
            basic_order = ['rf', 'sf', 'tf', 'rvc', 'ori']  # order doesn't matter if they are already loaded
            basics = hf.get_basic_tunings(basic_names, basic_order, preProc=S, reducedSave=True)
        except Exception:
            basics = None

    ### TEMPORARY: save the "basics" in curr_suppr; should live on its own, though; TODO
    curr_suppr['basics'] = basics

    # each lookup falls back to NaN if basics is None or lacks the field
    try:
        oriBW, oriCV = basics['ori']['bw'], basics['ori']['cv']
    except Exception:
        oriBW, oriCV = np.nan, np.nan
    try:
        tfBW = basics['tf']['tfBW_oct']
    except Exception:
        tfBW = np.nan
    try:
        suprMod = basics['rfsize']['suprInd_model']
    except Exception:
        suprMod = np.nan
    try:
        suprDat = basics['rfsize']['suprInd_data']
    except Exception:
        suprDat = np.nan

    try:
        cellType = dataList['unitType'][which_cell - 1]
    except Exception:
        # TODO: note, this is dangerous; thus far, only V1 cells don't have 'unitType' field in dataList, so we can safely do this
        cellType = 'V1'

    ############
    ### compute f1f0 ratio, and load the corresponding F0 or F1 responses
    ############
    f1f0_rat = hf.compute_f1f0(expData, which_cell, expInd, dataPath, descrFitName_f0=descrFits_f0)[0]
    curr_suppr['f1f0'] = f1f0_rat
    respMeasure = 1 if f1f0_rat > 1 else 0

    if vecF1 == 1:
        # get the correct, adjusted F1 response
        if expInd > f1_expCutoff and respMeasure == 1:
            respOverwrite = hf.adjust_f1_byTrial(expData, expInd)
        else:
            respOverwrite = None

    if (respMeasure == 1 or expDir == 'LGN/') and expDir != 'altExp/':
        # i.e. if we're looking at a simple cell, then let's get F1
        if vecF1 == 1:
            # NOTE(review): respOverwrite is None when expInd <= f1_expCutoff, and the
            # indexing below would then fail -- confirm such experiments never reach here
            spikes_byComp = respOverwrite
            # then, sum up the valid components per stimulus component
            allCons = np.vstack(expData['con']).transpose()
            blanks = np.where(allCons == 0)
            spikes_byComp[blanks] = 0  # just set it to 0 if that component was blank during the trial
        else:
            if rvcName is not None:
                try:
                    rvcFits = hf.get_rvc_fits(dataPath, expInd, which_cell, rvcName=rvcName, rvcMod=rvcMod, direc=rvcDir, vecF1=vecF1)
                except Exception:
                    rvcFits = None
            else:
                rvcFits = None
            spikes_byComp = hf.get_spikes(expData, get_f0=0, rvcFits=rvcFits, expInd=expInd)
        spikes = np.array([np.sum(x) for x in spikes_byComp])
        rates = vecF1 == 0  # when we get the spikes from rvcFits, they've already been converted into rates (in hf.get_all_fft)
        baseline = None  # f1 has no "DC", yadig?
    else:  # otherwise, if it's complex, just get F0
        respMeasure = 0
        spikes = hf.get_spikes(expData, get_f0=1, rvcFits=None, expInd=expInd)
        rates = False  # get_spikes without rvcFits is directly from spikeCount, which is counts, not rates!
        baseline = hf.blankResp(expData, expInd)[0]  # we'll plot the spontaneous rate
        # why mult by stimDur? well, spikes are not rates but baseline is, so we convert baseline to count (i.e. not rate, too)
        spikes = spikes - baseline * hf.get_exp_params(expInd).stimDur

    _, _, _, respAll = hf.organize_resp(spikes, expData, expInd, respsAsRate=rates)  # only using respAll to get variance measures
    resps_data, stimVals, val_con_by_disp, _, _ = hf.tabulate_responses(expData, expInd, overwriteSpikes=spikes, respsAsRates=rates, modsAsRate=rates)

    if fitList is None:
        resps = resps_data  # otherwise, we'll still keep resps_data for reference
    else:  # OVERWRITE the data with the model spikes!
        if use_mod_resp == 1:
            curr_fit = fitList[which_cell - 1]['params']
            modResp = mod_resp.SFMGiveBof(curr_fit, S, normType=fitType, lossType=lossType, expInd=expInd, cellNum=which_cell, excType=excType)[1]
            if f1f0_rat < 1 and baseline is not None:  # then subtract baseline (only defined on the F0 path)
                modResp = modResp - baseline * hf.get_exp_params(expInd).stimDur
            # now organize the responses
            resps, stimVals, val_con_by_disp, _, _ = hf.tabulate_responses(expData, expInd, overwriteSpikes=modResp, respsAsRates=False, modsAsRate=False)
        elif use_mod_resp == 2:  # then pytorch model!
            resp_str = hf_sf.get_resp_str(respMeasure)
            curr_fit = fitList[which_cell - 1][resp_str]['params']
            model = mrpt.sfNormMod(curr_fit, expInd=expInd, excType=excType, normType=fitType, lossType=lossType, lgnFrontEnd=lgnFrontEnd, newMethod=newMethod, lgnConType=conType, applyLGNtoNorm=_applyLGNtoNorm)
            ### get the vec-corrected responses, if applicable
            if expInd > f1_expCutoff and respMeasure == 1:
                respOverwrite = hf.adjust_f1_byTrial(expData, expInd)
            else:
                respOverwrite = None

            dw = mrpt.dataWrapper(expData, respMeasure=respMeasure, expInd=expInd, respOverwrite=respOverwrite)  # respOverwrite is None if DC or if expInd <= f1_expCutoff
            modResp = model.forward(dw.trInf, respMeasure=respMeasure, sigmoidSigma=_sigmoidSigma, recenter_norm=recenter_norm).detach().numpy()

            if respMeasure == 1:
                # make sure the blank components have a zero response (we'll do the same with the measured responses)
                blanks = np.where(dw.trInf['con'] == 0)
                modResp[blanks] = 0
                # next, sum up across components
                modResp = np.sum(modResp, axis=1)
            # finally, make sure this fills out a vector of all responses (just have nan for non-modelled trials)
            nTrialsFull = len(expData['num'])
            modResp_full = np.nan * np.zeros((nTrialsFull, ))
            modResp_full[dw.trInf['num']] = modResp

            if respMeasure == 0 and baseline is not None:
                # if DC, then subtract baseline..., as determined from data (why not model? we aren't yet calc. response to no stim, though it can be done)
                modResp_full = modResp_full - baseline * hf.get_exp_params(expInd).stimDur

            # TODO: This is a work around for which measures are in rates vs. counts (DC vs F1, model vs data...)
            asRates = False
            # now organize the responses
            resps, stimVals, val_con_by_disp, _, _ = hf.tabulate_responses(expData, expInd, overwriteSpikes=modResp_full, respsAsRates=asRates, modsAsRate=asRates)

    predResps = resps[2]
    respMean = resps[0]
    # variance measures over all trials of a given condition (not used further below)
    respStd = np.nanstd(respAll, -1)
    nonNaN = np.sum(~np.isnan(respAll), axis=-1)
    respSem = np.nanstd(respAll, -1) / np.sqrt(nonNaN)

    ############
    ### first, fit a smooth function to the overall pred V measured responses
    ### --- from this, we can measure how each example superposition deviates from a central tendency
    ### --- i.e. the residual relative to the "standard" input:output relationship
    ############
    all_resps = respMean[1:, :, :].flatten()  # all disp>0
    all_preds = predResps[1:, :, :].flatten()  # all disp>0
    # naka rushton (baseline fixed at 0)
    myFit = lambda x, g, expon, c50: hf.naka_rushton(x, [0, g, expon, c50])
    non_neg = np.where(all_preds > 0)  # cannot fit negative values with naka-rushton...
    try:
        if use_mod_resp == 1:  # the reference will ALWAYS be the data -- redo the above analysis for data
            predResps_data = resps_data[2]
            respMean_data = resps_data[0]
            all_resps_data = respMean_data[1:, :, :].flatten()  # all disp>0
            all_preds_data = predResps_data[1:, :, :].flatten()  # all disp>0
            non_neg_data = np.where(all_preds_data > 0)  # cannot fit negative values with naka-rushton...
            fitz, _ = opt.curve_fit(myFit, all_preds_data[non_neg_data], all_resps_data[non_neg_data], p0=[100, 2, 25], maxfev=5000)
        else:
            fitz, _ = opt.curve_fit(myFit, all_preds[non_neg], all_resps[non_neg], p0=[100, 2, 25], maxfev=5000)
        rel_c50 = np.divide(fitz[-1], np.max(all_preds[non_neg]))
    except Exception:
        fitz = None
        rel_c50 = -99

    ############
    ### organize stimulus information
    ############
    all_disps = stimVals[0]
    all_cons = stimVals[1]
    all_sfs = stimVals[2]
    nDisps = len(all_disps)

    maxResp = np.maximum(np.nanmax(respMean), np.nanmax(predResps))
    # by disp
    clrs_d = cm.viridis(np.linspace(0, 0.75, nDisps - 1))
    lbls_d = ['disp: %s' % str(x) for x in range(nDisps)]
    # by sf
    val_sfs = hf.get_valid_sfs(S, disp=1, con=val_con_by_disp[1][0], expInd=expInd)  # pick
    clrs_sf = cm.viridis(np.linspace(0, .75, len(val_sfs)))
    lbls_sf = ['sf: %.2f' % all_sfs[x] for x in val_sfs]

    ############
    ### create the figure
    ############
    fSuper, ax = plt.subplots(nRows, nCols, figsize=(10 * nCols, 8 * nRows))
    sns.despine(fig=fSuper, offset=10)

    allMix = []
    allSum = []

    ### plot reference tuning [row 1 (i.e. 2nd row)]
    ## on the right, SF tuning (high contrast)
    sfRef = hf.nan_rm(respMean[0, :, -1])  # high contrast tuning
    ax[1, 1].plot(all_sfs, sfRef, 'k-', marker='o', label='ref. tuning (d0, high con)', clip_on=False)
    ax[1, 1].set_xscale('log')
    ax[1, 1].set_xlim((0.1, 10))
    ax[1, 1].set_xlabel('sf (c/deg)')
    ax[1, 1].set_ylabel('response (spikes/s)')
    ax[1, 1].set_ylim((-5, 1.1 * np.nanmax(sfRef)))
    ax[1, 1].legend(fontsize='x-small')

    #####
    ## then on the left, RVC (peak SF)
    #####
    sfPeak = np.argmax(sfRef)  # stupid/simple, but just get the rvc for the max response
    v_cons_single = val_con_by_disp[0]
    rvcRef = hf.nan_rm(respMean[0, sfPeak, v_cons_single])
    # now, if possible, let's also plot the RVC fit
    if rvcFits is not None:
        # NOTE(review): this reload omits direc/vecF1, unlike the load above -- confirm both resolve to the same fits
        rvcFits = hf.get_rvc_fits(dataPath, expInd, which_cell, rvcName=rvcName, rvcMod=rvcMod)
        rel_rvc = rvcFits[0]['params'][sfPeak]  # we get 0 dispersion, peak SF
        plt_cons = np.geomspace(all_cons[0], all_cons[-1], 50)
        c50, pk = hf.get_c50(rvcMod, rel_rvc), rvcFits[0]['conGain'][sfPeak]
        c50_emp, c50_eval = hf.c50_empirical(rvcMod, rel_rvc)  # determine c50 by optimization, numerical approx.
        if rvcMod == 0:
            rvc_mod = hf.get_rvc_model()
            rvcmodResp = rvc_mod(*rel_rvc, plt_cons)
        else:  # i.e. mod=1 or mod=2
            rvcmodResp = hf.naka_rushton(plt_cons, rel_rvc)
        if baseline is not None:
            rvcmodResp = rvcmodResp - baseline
        ax[1, 0].plot(plt_cons, rvcmodResp, 'k--', label='rvc fit (c50=%.2f, gain=%0f)' % (c50, pk))
        # and save it
        curr_suppr['c50'] = c50
        curr_suppr['conGain'] = pk
        curr_suppr['c50_emp'] = c50_emp
        curr_suppr['c50_emp_eval'] = c50_eval
    else:
        curr_suppr['c50'] = np.nan
        curr_suppr['conGain'] = np.nan
        curr_suppr['c50_emp'] = np.nan
        curr_suppr['c50_emp_eval'] = np.nan

    ax[1, 0].plot(all_cons[v_cons_single], rvcRef, 'k-', marker='o', label='ref. tuning (d0, peak SF)', clip_on=False)
    ax[1, 0].set_xlabel('contrast (%)')
    ax[1, 0].set_ylabel('response (spikes/s)')
    ax[1, 0].set_ylim((-5, 1.1 * np.nanmax(rvcRef)))
    ax[1, 0].legend(fontsize='x-small')

    # plot the fitted model on each axis
    pred_plt = np.linspace(0, np.nanmax(all_preds), 100)
    if fitz is not None:
        ax[0, 0].plot(pred_plt, myFit(pred_plt, *fitz), 'r--', label='fit')
        ax[0, 1].plot(pred_plt, myFit(pred_plt, *fitz), 'r--', label='fit')

    dispRats = []
    for d in range(nDisps):
        if d == 0:  # we don't care about single gratings!
            continue
        v_cons = np.array(val_con_by_disp[d])
        n_v_cons = len(v_cons)

        # plot split out by each contrast [0,1]
        for c in reversed(range(n_v_cons)):
            v_sfs = hf.get_valid_sfs(S, d, v_cons[c], expInd)
            for s in v_sfs:
                mixResp = respMean[d, s, v_cons[c]]
                allMix.append(mixResp)
                sumResp = predResps[d, s, v_cons[c]]
                allSum.append(sumResp)
                # PLOT in by-disp panel (label only the first point of each dispersion)
                if c == 0 and s == v_sfs[0]:
                    ax[0, 0].plot(sumResp, mixResp, 'o', color=clrs_d[d - 1], label=lbls_d[d], clip_on=False)
                else:
                    ax[0, 0].plot(sumResp, mixResp, 'o', color=clrs_d[d - 1], clip_on=False)
                # PLOT in by-sf panel
                sfInd = np.where(np.array(v_sfs) == s)[0][0]  # will only be one entry, so just "unpack"
                try:
                    if d == 1 and c == 0:
                        ax[0, 1].plot(sumResp, mixResp, 'o', color=clrs_sf[sfInd], label=lbls_sf[sfInd], clip_on=False)
                    else:
                        ax[0, 1].plot(sumResp, mixResp, 'o', color=clrs_sf[sfInd], clip_on=False)
                except Exception:
                    # best-effort: e.g. sfInd beyond the colors/labels built from disp=1's valid sfs
                    pass

        # plot averaged across all cons/sfs (i.e. average for the whole dispersion) [1,0]
        mixDisp = respMean[d, :, :].flatten()
        sumDisp = predResps[d, :, :].flatten()
        mixDisp, sumDisp = zr_rm_pair(mixDisp, sumDisp, 0.5)
        curr_rats = np.divide(mixDisp, sumDisp)
        curr_mn = geomean(curr_rats)
        curr_std = np.std(np.log10(curr_rats))
        ax[2, 0].bar(d, curr_mn, yerr=curr_std, color=clrs_d[d - 1])
        ax[2, 0].set_yscale('log')
        ax[2, 0].set_ylim(0.1, 10)
        dispRats.append(curr_mn)
        # also, let's plot the (signed) error relative to the fit
        if fitz is not None:
            errs = mixDisp - myFit(sumDisp, *fitz)
            ax[3, 0].bar(d, np.mean(errs), yerr=np.std(errs), color=clrs_d[d - 1])
            # -- and normalized by the prediction output response
            errs_norm = np.divide(mixDisp - myFit(sumDisp, *fitz), myFit(sumDisp, *fitz))
            ax[4, 0].bar(d, np.mean(errs_norm), yerr=np.std(errs_norm), color=clrs_d[d - 1])

        # and set some labels/lines, as needed
        if d == 1:
            ax[2, 0].set_xlabel('dispersion')
            ax[2, 0].set_ylabel('suppression ratio (linear)')
            ax[2, 0].axhline(1, ls='--', color='k')
            ax[3, 0].set_xlabel('dispersion')
            ax[3, 0].set_ylabel('mean (signed) error')
            ax[3, 0].axhline(0, ls='--', color='k')
            ax[4, 0].set_xlabel('dispersion')
            ax[4, 0].set_ylabel('mean (signed) error -- as frac. of fit prediction')
            ax[4, 0].axhline(0, ls='--', color='k')

    curr_suppr['supr_disp'] = dispRats

    ### plot averaged across all cons/disps
    sfInds = []
    sfRats = []
    sfRatStd = []
    sfErrs = []
    sfErrsStd = []
    sfErrsInd = []
    sfErrsIndStd = []
    sfErrsRat = []
    sfErrsRatStd = []
    curr_errNormFactor = []
    for s in range(len(val_sfs)):
        try:  # not all sfs will have legitimate values;
            # only get mixtures (i.e. ignore single gratings)
            mixSf = respMean[1:, val_sfs[s], :].flatten()
            sumSf = predResps[1:, val_sfs[s], :].flatten()
            mixSf, sumSf = zr_rm_pair(mixSf, sumSf, 0.5)
            rats_curr = np.divide(mixSf, sumSf)
            sfInds.append(s)
            sfRats.append(geomean(rats_curr))
            sfRatStd.append(np.std(np.log10(rats_curr)))

            if fitz is not None:
                curr_NR = np.maximum(myFit(sumSf, *fitz), 0.5)  # thresholded at 0.5...

                curr_err = mixSf - curr_NR
                sfErrs.append(np.mean(curr_err))
                sfErrsStd.append(np.std(curr_err))

                curr_errNorm = np.divide(mixSf - curr_NR, mixSf + curr_NR)
                sfErrsInd.append(np.mean(curr_errNorm))
                sfErrsIndStd.append(np.std(curr_errNorm))

                curr_errRat = np.divide(mixSf, curr_NR)
                sfErrsRat.append(np.mean(curr_errRat))
                sfErrsRatStd.append(np.std(curr_errRat))

                curr_normFactors = np.array(curr_NR)
                curr_errNormFactor.append(geomean(curr_normFactors[curr_normFactors > 0]))
            else:
                sfErrs.append([])
                sfErrsStd.append([])
                sfErrsInd.append([])
                sfErrsIndStd.append([])
                sfErrsRat.append([])
                sfErrsRatStd.append([])
                curr_errNormFactor.append([])
        except Exception:
            pass

    # get the offset/scale of the ratio so that we can plot a rescaled/flipped version of
    # the high con/single grat tuning for reference...does the suppression match the response?
    offset, scale = np.nanmax(sfRats), np.nanmax(sfRats) - np.nanmin(sfRats)
    sfRef = hf.nan_rm(respMean[0, val_sfs, -1])  # high contrast tuning
    sfRefShift = offset - scale * (sfRef / np.nanmax(sfRef))
    ax[2, 1].scatter(all_sfs[val_sfs][sfInds], sfRats, color=clrs_sf[sfInds], clip_on=False)
    ax[2, 1].errorbar(all_sfs[val_sfs][sfInds], sfRats, sfRatStd, color='k', linestyle='-', clip_on=False, label='suppression tuning')
    ax[2, 1].plot(all_sfs[val_sfs], sfRefShift, 'k--', label='ref. tuning', clip_on=False)
    ax[2, 1].axhline(1, ls='--', color='k')
    ax[2, 1].set_xlabel('sf (cpd)')
    ax[2, 1].set_xscale('log')
    ax[2, 1].set_xlim((0.1, 10))
    ax[2, 1].set_ylabel('suppression ratio')
    ax[2, 1].set_yscale('log')
    ax[2, 1].set_ylim(0.1, 10)
    ax[2, 1].legend(fontsize='x-small')
    curr_suppr['supr_sf'] = sfRats

    ### residuals from fit of suppression
    if fitz is not None:
        # mean signed error: and labels/plots for the error as f'n of SF
        ax[3, 1].axhline(0, ls='--', color='k')
        ax[3, 1].set_xlabel('sf (cpd)')
        ax[3, 1].set_xscale('log')
        ax[3, 1].set_xlim((0.1, 10))
        ax[3, 1].set_ylabel('mean (signed) error')
        ax[3, 1].errorbar(all_sfs[val_sfs][sfInds], sfErrs, sfErrsStd, color='k', marker='o', linestyle='-', clip_on=False)
        # -- and normalized by the prediction output response + output response
        val_errs = np.logical_and(~np.isnan(sfErrsRat), np.logical_and(np.array(sfErrsIndStd) > 0, np.array(sfErrsIndStd) < 2))
        norm_subset = np.array(sfErrsInd)[val_errs]
        normStd_subset = np.array(sfErrsIndStd)[val_errs]
        ax[4, 1].axhline(0, ls='--', color='k')
        ax[4, 1].set_xlabel('sf (cpd)')
        ax[4, 1].set_xscale('log')
        ax[4, 1].set_xlim((0.1, 10))
        ax[4, 1].set_ylim((-1, 1))
        ax[4, 1].set_ylabel('error index')
        ax[4, 1].errorbar(all_sfs[val_sfs][sfInds][val_errs], norm_subset, normStd_subset, color='k', marker='o', linestyle='-', clip_on=False)
        # -- AND simply the ratio between the mixture response and the mean expected mix response (i.e. Naka-Rushton)
        # --- equivalent to the suppression ratio, but relative to the NR fit rather than perfect linear summation
        val_errs = np.logical_and(~np.isnan(sfErrsRat), np.logical_and(np.array(sfErrsRatStd) > 0, np.array(sfErrsRatStd) < 2))
        rat_subset = np.array(sfErrsRat)[val_errs]
        ratStd_subset = np.array(sfErrsRatStd)[val_errs]
        ax[5, 1].scatter(all_sfs[val_sfs][sfInds][val_errs], rat_subset, color=clrs_sf[sfInds][val_errs], clip_on=False)
        ax[5, 1].errorbar(all_sfs[val_sfs][sfInds][val_errs], rat_subset, ratStd_subset, color='k', linestyle='-', clip_on=False, label='suppression tuning')
        ax[5, 1].axhline(1, ls='--', color='k')
        ax[5, 1].set_xlabel('sf (cpd)')
        ax[5, 1].set_xscale('log')
        ax[5, 1].set_xlim((0.1, 10))
        ax[5, 1].set_ylabel('suppression ratio (wrt NR)')
        # 'basey' was deprecated in matplotlib 3.3 and later removed; 'base' is the current kwarg
        ax[5, 1].set_yscale('log', base=2)
        ax[5, 1].set_ylim(np.power(2.0, -2), np.power(2.0, 2))
        ax[5, 1].legend(fontsize='x-small')
        # - compute the variance - and put that value on the plot
        errsRatVar = np.var(np.log2(sfErrsRat)[val_errs])
        curr_suppr['sfRat_VAR'] = errsRatVar
        ax[5, 1].text(0.1, 2, 'var=%.2f' % errsRatVar)

        # variance of the error index over the valid SFs
        val_errs = np.logical_and(~np.isnan(sfErrsRat), np.logical_and(np.array(sfErrsIndStd) > 0, np.array(sfErrsIndStd) < 2))
        ind_var = np.var(np.array(sfErrsInd)[val_errs])
        curr_suppr['sfErrsInd_VAR'] = ind_var
        # - and put that value on the plot
        ax[4, 1].text(0.1, -0.25, 'var=%.3f' % ind_var)
    else:
        curr_suppr['sfErrsInd_VAR'] = np.nan
        curr_suppr['sfRat_VAR'] = np.nan

    #########
    ### NOW, let's evaluate the derivative of the SF tuning curve and get the correlation with the errors
    #########
    mod_sfs = np.geomspace(all_sfs[0], all_sfs[-1], 1000)
    # NOTE: must not be named `mod_resp` -- assigning that name here would make the
    # module `mod_resp` (used above for use_mod_resp == 1) a local, raising UnboundLocalError
    descr_resp = hf.get_descrResp(dfit_curr, mod_sfs, DoGmodel=dMod_num)
    deriv = np.divide(np.diff(descr_resp), np.diff(np.log10(mod_sfs)))
    deriv_norm = np.divide(deriv, np.maximum(np.nanmax(deriv), np.abs(np.nanmin(deriv))))  # make the maximum response 1 (or -1)
    # - then, what indices to evaluate for comparing with sfErr? (deriv is aligned with mod_sfs[1:])
    errSfs = all_sfs[val_sfs][sfInds]
    mod_inds = _deriv_eval_inds(mod_sfs, errSfs)
    deriv_norm_eval = deriv_norm[mod_inds]
    # -- plot on [1, 1] (i.e. where the data is)
    ax[1, 1].plot(mod_sfs, descr_resp, 'k--', label='fit (g)')
    ax[1, 1].legend()
    # Duplicate "twin" the axis to create a second y-axis
    ax2 = ax[1, 1].twinx()
    ax2.set_xscale('log')  # have to re-inforce log-scale?
    ax2.set_ylim([-1, 1])  # since the g' is normalized
    # make a plot with different y-axis using second axis object
    ax2.plot(mod_sfs[1:], deriv_norm, '--', color="red", label='g\'')
    ax2.set_ylabel("deriv. (normalized)", color="red")
    ax2.legend()
    sns.despine(ax=ax2, offset=10, right=False)
    # -- and let's plot rescaled and shifted version in [2,1]
    offset, scale = np.nanmax(sfRats), np.nanmax(sfRats) - np.nanmin(sfRats)
    derivShift = offset - scale * (deriv_norm / np.nanmax(deriv_norm))
    ax[2, 1].plot(mod_sfs[1:], derivShift, 'r--', label='deriv(ref. tuning)', clip_on=False)
    ax[2, 1].legend(fontsize='x-small')
    # - then, normalize the sfErrs/sfErrsInd and compute the correlation coefficient
    if fitz is not None:
        norm_sfErr = np.divide(sfErrs, np.nanmax(np.abs(sfErrs)))
        norm_sfErrInd = np.divide(sfErrsInd, np.nanmax(np.abs(sfErrsInd)))  # remember, sfErrsInd is normalized per condition; this is overall
        non_nan = np.logical_and(~np.isnan(norm_sfErr), ~np.isnan(deriv_norm_eval))
        corr_nsf = np.corrcoef(deriv_norm_eval[non_nan], norm_sfErr[non_nan])[0, 1]
        corr_nsfN = np.corrcoef(deriv_norm_eval[non_nan], norm_sfErrInd[non_nan])[0, 1]
        curr_suppr['corr_derivWithErr'] = corr_nsf
        curr_suppr['corr_derivWithErrsInd'] = corr_nsfN
        ax[3, 1].text(0.1, 0.25 * np.nanmax(sfErrs), 'corr w/g\' = %.2f' % corr_nsf)
        ax[4, 1].text(0.1, 0.25, 'corr w/g\' = %.2f' % corr_nsfN)
    else:
        curr_suppr['corr_derivWithErr'] = np.nan
        curr_suppr['corr_derivWithErrsInd'] = np.nan

    # make a polynomial fit
    try:
        hmm = np.polyfit(allSum, allMix, deg=1)  # returns [a, b] in ax + b
    except Exception:
        hmm = [np.nan]
    curr_suppr['supr_index'] = hmm[0]

    # top row: mixture vs. sum-of-components scatter panels
    for jj in range(nCols):
        ax[0, jj].axis('square')
        ax[0, jj].set_xlabel('prediction: sum(components) (imp/s)')
        ax[0, jj].set_ylabel('mixture response (imp/s)')
        ax[0, jj].plot([0, 1 * maxResp], [0, 1 * maxResp], 'k--')
        ax[0, jj].set_xlim((-5, maxResp))
        ax[0, jj].set_ylim((-5, 1.1 * maxResp))
        ax[0, jj].set_title('Suppression index: %.2f|%.2f' % (hmm[0], rel_c50))
        ax[0, jj].legend(fontsize='x-small')

    fSuper.suptitle('Superposition: %s #%d [%s; f1f0 %.2f; szSupr[dt/md] %.2f/%.2f; oriBW|CV %.2f|%.2f; tfBW %.2f]' % (cellType, which_cell, cellName, f1f0_rat, suprDat, suprMod, oriBW, oriCV, tfBW))

    if fitList is None:
        save_name = 'cell_%03d.pdf' % which_cell
    else:
        save_name = 'cell_%03d_mod%s.pdf' % (which_cell, hf.fitType_suffix(fitType))
    with pltSave.PdfPages(str(save_locSuper + save_name)) as pdfSv:
        pdfSv.savefig(fSuper)
    plt.close(fSuper)  # release the figure so repeated calls don't accumulate open figures

    #########
    ### Finally, add this "superposition" to the newest
    #########
    if to_save:
        if fitList is None:
            from datetime import datetime
            suffix = datetime.today().strftime('%y%m%d')
            super_name = 'superposition_analysis_%s.npy' % suffix
        else:
            super_name = 'superposition_analysis_mod%s.npy' % hf.fitType_suffix(fitType)

        # random pause before the read-modify-write of the shared results file
        pause_tm = 5 * np.random.rand()
        print('sleeping for %d secs (#%d)' % (pause_tm, which_cell))
        time.sleep(pause_tm)

        if os.path.exists(dataPath + super_name):
            suppr_all = hf.np_smart_load(dataPath + super_name)
        else:
            suppr_all = dict()
        suppr_all[which_cell - 1] = curr_suppr
        np.save(dataPath + super_name, suppr_all)

    return curr_suppr
spikes = np.array([np.sum(x) for x in spikes_byComp]) rates = True # when we get the spikes from rvcFits, they've already been converted into rates (in hf.get_all_fft) baseline_sfMix = None # f1 has no "DC", yadig? else: # otherwise, if it's complex, just get F0 spikes = hf.get_spikes(expData, get_f0=1, rvcFits=None, expInd=expInd) rates = False # get_spikes without rvcFits is directly from spikeCount, which is counts, not rates! baseline_sfMix = hf.blankResp(expData, expInd)[0] # we'll plot the spontaneous rate # why mult by stimDur? well, spikes are not rates but baseline is, so we convert baseline to count (i.e. not rate, too) spikes = spikes - baseline_sfMix * hf.get_exp_params(expInd).stimDur _, _, respOrg, respAll = hf.organize_resp(spikes, expData, expInd) resps, stimVals, val_con_by_disp, _, _ = hf.tabulate_responses( expData, expInd, overwriteSpikes=spikes, respsAsRates=rates) predResps = resps[2] respMean = resps[0] # equivalent to resps[0]; respStd = np.nanstd(respAll, -1) # take std of all responses for a given condition # compute SEM, too findNaN = np.isnan(respAll) nonNaN = np.sum(findNaN == False, axis=-1) respSem = np.nanstd(respAll, -1) / np.sqrt(nonNaN) # organize stimulus values all_disps = stimVals[0] all_cons = stimVals[1] all_sfs = stimVals[2]