def get_stas(Stim, Robs, dims, valid=None, num_lags=10, plot=True):
    """Compute spike-triggered averages (STAs) for every unit in ``Robs``.

    Parameters
    ----------
    Stim : array
        Stimulus passed to ``create_time_embedding_valid`` with embedding
        dims ``[num_lags, NX, NY]``.
    Robs : array, shape (NT, NC)
        Observed responses, one column per unit.
    dims : sequence
        ``(NX, NY)`` spatial size of the stimulus.
    valid : array of int, optional
        Sample indices to use; defaults to all ``NT`` samples.
    num_lags : int
        Number of time lags in the embedding.
    plot : bool
        If True, plot for each unit the |STA| time course of every pixel,
        with a horizontal reference line at 6x the median absolute deviation
        (from the mean) of that unit's STA.

    Returns
    -------
    stas : array, shape (NX*NY, num_lags, NC)
        Time-embedded stimulus times mean-subtracted responses, divided by
        the number of valid samples.
    """
    NX = dims[0]
    NY = dims[1]
    NT, NC = Robs.shape

    if valid is None:
        valid = np.arange(0, NT, 1)

    # create time-embedded stimulus restricted to the valid samples
    Xstim, rinds = create_time_embedding_valid(Stim, [num_lags, NX, NY], valid)
    Rvalid = deepcopy(Robs[rinds, :])

    NTv = Rvalid.shape[0]
    print('%d valid samples of %d possible' % (NTv, NT))

    # correlate the stimulus with mean-subtracted responses
    stas = Xstim.T @ (Rvalid - np.average(Rvalid, axis=0))
    stas = np.reshape(stas, [NX * NY, num_lags, NC]) / NTv

    # The per-unit threshold is only used for the figure, so it is computed
    # only when plotting (it was previously computed and discarded).
    if plot:
        plt.figure(figsize=(10, 15))
        sx, sy = U.get_subplot_dims(NC)
        for cc in range(NC):
            sta = stas[:, :, cc]
            plt.subplot(sx, sy, cc + 1)
            plt.plot(np.abs(sta).T, color=[.5, .5, .5])
            # reference level: 6x median absolute deviation from the mean
            tlevel = np.median(np.abs(sta - np.average(sta))) * 6
            plt.axhline(tlevel, color='k')
            plt.title(cc)

    return stas
valid_eye_rad = 5.2 # degrees -- use this when looking at eye-calibration (see below) ppd = 37.50476617061 eyeX = (eyeAtFrame[:,0]-640)/ppd eyeY = (eyeAtFrame[:,1]-380)/ppd eyeCentered = np.hypot(eyeX, eyeY) < valid_eye_rad # eyeCentered = np.logical_and(eyeX < 0, eyeCentered) valid_inds = np.intersect1d(valdata, np.where(eyeCentered)[0]) # %% quick check STAS stas = ne.get_stas(Stim, Robs, [NX,NY], valid=valid_inds, num_lags=10, plot=False) plt.figure(figsize=(10,10)) NC = Robs.shape[1] sx,sy = U.get_subplot_dims(NC) sumdensity = 0 for cc in range(NC): plt.subplot(sx,sy,cc+1) bestlag = np.argmax(np.max(abs(stas[:,:,cc]),axis=0)) plt.imshow(np.reshape(stas[:,bestlag,cc], (NY,NX))) sumdensity += np.abs(stas[:,bestlag,cc]) plt.title(cc) plt.axis("off") plt.figure() plt.imshow(np.reshape(sumdensity, (NY, NX))) #%% # Cxinds = ne.crop_indx(NX, range(9,24), range(9,24)) Cxinds = ne.crop_indx(NX, range(5,25), range(5,25))
# NOTE(review): the statements down to sns.despine(...) index with `cc` and use
# subplot `cc + 1`, so they look like the body of a per-unit loop whose header
# is not part of this fragment -- confirm against the original notebook.
# Compare tuning curves of the data (TC) with two models (TC1: yhat0, TC2: yhat1).
r = get_psth(Robs[:, cc], gratingDir, win)
rhat0 = get_psth(yhat0[:, cc], gratingDir, win)
rhat1 = get_psth(yhat1[:, cc], gratingDir, win)
# column cc: first output of get_psth summed over axis 0
TC[:, cc] = np.sum(r[0], axis=0)
TC1[:, cc] = np.sum(rhat0[0], axis=0)
TC2[:, cc] = np.sum(rhat1[0], axis=0)
plt.subplot(sx, sy, cc + 1)
plt.plot(directions, TC[:, cc], 'k-o')
plt.plot(directions, TC1[:, cc], 'r')
plt.plot(directions, TC2[:, cc], 'b')
plt.xticks(np.arange(0, 360, 90))
sns.despine(offset=0, trim=True)

# R^2 of each model's tuning curve against the data's, with the identity line
plt.figure()
r2Base = U.r_squared(TC, TC1)
r2Dir = U.r_squared(TC, TC2)
plt.plot(r2Base, r2Dir, 'o')
plt.xlim((0, 1))
plt.ylim((0, 1))
plt.plot(plt.xlim(), plt.xlim(), 'k')
plt.title("R-squared Tuning Curve")

# R^2 of the raw predictions on the Xi samples (presumably the held-out fold)
r2Base = U.r_squared(Robs[Xi, :], yhat0[Xi, :])
r2Dir = U.r_squared(Robs[Xi, :], yhat1[Xi, :])
plt.figure()
plt.plot(r2Base, r2Dir, 'o')
plt.plot(plt.xlim(), plt.xlim(), 'k')
plt.xlabel("Var Explained (Base)")
plt.ylabel("Var Explained (Dir)")
def get_stim_model(Stim, Robs, dims, valid=None, num_lags=10, plot=True,
                   XTreg=0.05, L1reg=5e-3, MIreg=0.1, MSCreg=10.0, Greg=0.1,
                   Mreg=1e-4, num_subs=36, num_hid=24, num_tkern=None,
                   Cindx=None, base_mod=None, cids=None, autoencoder=False):
    """Select units with reliable STAs, crop the stimulus around their RFs,
    and fit (or refit) an NDN scaffold model.

    Parameters
    ----------
    Stim : array
        Stimulus whose columns are pixels (it is cropped as ``Stim[:, Cindx]``)
        and which is time-embedded with dims ``[num_lags, NX, NY]``.
    Robs : array, shape (NT, NC)
        Observed responses, one column per unit.
    dims : sequence
        ``(NX, NY)`` spatial size of the stimulus.
    valid : array of int, optional
        Sample indices to use; defaults to all ``NT`` samples.
    num_lags : int
        Number of time lags in the embedding.
    plot : bool
        If True, draw diagnostic figures.
    XTreg : float or sequence
        Smoothness regularization. A scalar is used for both d2t and d2x; a
        length-2 sequence is ``(d2t, d2x)``; any other sequence uses its first
        element for both.
    L1reg, Greg : float
        'l1' and 'glocal' regularization set on layer 0 before the final fit.
    MIreg : float
        'max' regularization on ffnetwork 0, layer 1, before the final fit.
    MSCreg : float
        'max' regularization on ffnetwork 1, layer 0, before the final fit.
    Mreg : float
        NOTE(review): currently unused -- the side network's initial 'max'
        penalty is the hard-coded ``Mreg0``.
    num_subs, num_hid : int
        Sizes of the two 'normal' layers of the stimulus network.
    num_tkern : int, optional
        If given, a 'conv' layer with this many kernels is placed first.
    Cindx : array of int, optional
        Pixel indices of a square crop; computed from the STAs when None.
    base_mod : NDN, optional
        If given, a copy of this model is re-regularized and trained instead
        of building a new one.
    cids : array of int, optional
        Units to analyze; when None, units whose STA passes the threshold.
    autoencoder : bool
        When building a new model, add an autoencoder network that takes the
        responses as a second input.

    Returns
    -------
    side2b, Xstim, Rvalid, rinds, cids, Cindx
        Trained model, cropped time-embedded stimulus, responses of the
        selected units after ``NDNutils.shift_mat_zpad(..., -1, dim=0)``,
        valid row indices, selected unit ids and crop indices.

    Raises
    ------
    ValueError
        If no units are selected.
    """
    NX = dims[0]
    NY = dims[1]
    NT, NC = Robs.shape

    if valid is None:
        valid = np.arange(0, NT, 1)

    # create time-embedded stimulus
    Xstim, rinds = create_time_embedding_valid(Stim, [num_lags, NX, NY], valid)
    Rvalid = deepcopy(Robs[rinds, :])

    NTv = Rvalid.shape[0]
    print('%d valid samples of %d possible' % (NTv, NT))

    # STAs from mean-subtracted responses, shape (NX*NY, num_lags, NC)
    stas = Xstim.T @ (Rvalid - np.average(Rvalid, axis=0))
    stas = np.reshape(stas, [NX * NY, num_lags, NC]) / NTv

    if plot:
        plt.figure(figsize=(10, 15))
        sx, sy = U.get_subplot_dims(NC)

    # mu[cc]: fraction of STA entries whose magnitude exceeds 4x the median
    # absolute deviation (from the mean) of that unit's STA
    mu = np.zeros(NC)
    for cc in range(NC):
        if plot:
            plt.subplot(sx, sy, cc + 1)
            plt.plot(np.abs(stas[:, :, cc]).T, color=[.5, .5, .5])
        tlevel = np.median(
            np.abs(stas[:, :, cc] - np.average(stas[:, :, cc]))) * 4
        mu[cc] = np.average(np.abs(stas[:, :, cc]) > tlevel)

        if plot:
            plt.axhline(tlevel, color='k')
            plt.title(cc)

    # threshold good STAS
    thresh = 0.01
    if plot:
        plt.figure()
        plt.plot(mu, '-o')
        plt.axhline(thresh, color='k')
        plt.show()

    if cids is None:
        cids = np.where(mu > thresh)[0]  # units to analyze
        print("found %d good STAs" % len(cids))

    if len(cids) == 0:
        # nothing to crop around or fit; fail here with a clear message
        # instead of deep inside the crop/fit code
        raise ValueError("no units selected for analysis (cids is empty)")

    if plot:
        plt.figure(figsize=(10, 15))
        for cc in cids:
            plt.subplot(sx, sy, cc + 1)
            bestlag = np.argmax(np.max(abs(stas[:, :, cc]), axis=0))
            plt.imshow(np.reshape(stas[:, bestlag, cc], (NY, NX)))
            plt.title(cc)

    # index into "good" units
    Rvalid = Rvalid[:, cids]
    NC = Rvalid.shape[1]
    stas = stas[:, :, cids]

    if Cindx is None:
        print("Getting Crop Index")
        # Crop stimulus to center around RFs: sum the squared STA of every
        # selected unit at its peak lag
        sumdensity = np.zeros([NX * NY])
        for cc in range(NC):
            bestlag = np.argmax(np.max(abs(stas[:, :, cc]), axis=0))
            sumdensity += stas[:, bestlag, cc]**2

        if plot:
            plt.figure()
            plt.imshow(np.reshape(sumdensity, [NY, NX]))
            plt.title("Sum Density STA")

        # normalize to [0, 1] and take the bounding box of pixels above 0.3
        sumdensity = (sumdensity - np.min(sumdensity)) / \
            (np.max(sumdensity) - np.min(sumdensity))
        I = np.reshape(sumdensity, [NY, NX]) > .3
        xinds = np.where(np.sum(I, axis=0) > 0)[0]
        yinds = np.where(np.sum(I, axis=1) > 0)[0]

        # square window whose side is the longer side of the bounding box
        NX2 = np.maximum(len(xinds), len(yinds))
        # Widening the shorter side can push the window past the image edge;
        # shift its origin back so the window stays inside NX x NY.
        x0 = min(np.min(xinds), max(NX - NX2, 0))
        y0 = min(np.min(yinds), max(NY - NX2, 0))

        xinds = range(x0, x0 + NX2)
        yinds = range(y0, y0 + NX2)

        Cindx = crop_indx(NX, xinds, yinds)

        if plot:
            plt.figure()
            plt.imshow(np.reshape(sumdensity[Cindx], [NX2, NX2]))
            plt.title('Cropped')
            plt.show()

    # side length of the (assumed square) crop
    NX2 = np.sqrt(len(Cindx)).astype(int)

    # make new cropped stimulus
    Xstim, rinds = create_time_embedding_valid(
        Stim[:, Cindx], [num_lags, NX2, NX2], valid)

    # index into Robs
    Rvalid = deepcopy(Robs[rinds, :])
    Rvalid = Rvalid[:, cids]
    Rvalid = NDNutils.shift_mat_zpad(Rvalid, -1, dim=0)  # get rid of first lag

    NC = Rvalid.shape[1]  # new number of units
    NT = Rvalid.shape[0]
    print('%d valid samples of %d possible' % (NT, Stim.shape[0]))
    print('%d good units' % NC)

    if plot:
        # double-check STAS work with cropped stimulus (only used for this
        # figure, so the product is skipped when not plotting)
        stas = Xstim.T @ Rvalid
        stas = np.reshape(stas, [NX2 * NX2, num_lags, NC]) / NT
        plt.figure(figsize=(10, 15))
        for cc in range(NC):
            plt.subplot(sx, sy, cc + 1)
            bestlag = np.argmax(np.max(abs(stas[:, :, cc]), axis=0))
            plt.imshow(np.reshape(stas[:, bestlag, cc], (NX2, NX2)))
            plt.title(cc)
        plt.show()

    Ui, Xi = NDNutils.generate_xv_folds(NT)

    # fit SCAFFOLD MODEL
    # XTreg: scalar -> same d2t/d2x; length 2 -> (d2t, d2x);
    # other sequences -> first element for both
    try:
        if len(XTreg) == 2:
            d2t = XTreg[0]
            d2x = XTreg[1]
        else:
            d2t = XTreg[0]
            d2x = deepcopy(d2t)
    except TypeError:
        d2t = deepcopy(XTreg)
        d2x = deepcopy(XTreg)

    # optimizer parameters
    adam_params = U.def_adam_params()

    if base_mod is not None:
        # refit a copy of an existing model with the requested regularization
        side2b = base_mod.copy_model()
        side2b.set_regularization('d2t', d2t, layer_target=0)
        side2b.set_regularization('d2x', d2x, layer_target=0)
        side2b.set_regularization('glocal', Greg, layer_target=0)
        side2b.set_regularization('l1', L1reg, layer_target=0)
        side2b.set_regularization('max', MIreg, ffnet_target=0, layer_target=1)
        side2b.set_regularization('max', MSCreg, ffnet_target=1, layer_target=0)

        if len(side2b.networks) == 4:  # includes autoencoder network
            input_data = [Xstim, Rvalid]
        else:
            input_data = Xstim

    else:
        # Best regularization arrived at
        Greg0 = 1e-1
        Mreg0 = 1e-6
        L1reg0 = 1e-5

        if num_tkern is not None:
            # Use the parsed d2x (not raw XTreg) so a (d2t, d2x) pair is not
            # passed through as a list; identical for scalar XTreg.
            ndn_par = NDNutils.ffnetwork_params(
                input_dims=[1, NX2, NX2, num_lags],
                layer_sizes=[num_tkern, num_subs, num_hid],
                layer_types=['conv', 'normal', 'normal'],
                ei_layers=[None, num_subs // 2, num_hid // 2],
                conv_filter_widths=[1],
                normalization=[1, 1, 1],
                act_funcs=['lin', 'relu', 'relu'],
                verbose=True,
                reg_list={
                    'd2t': [1e-3],
                    'd2x': [None, d2x],
                    'l1': [L1reg0, L1reg0],
                    'glocal': [Greg0, Greg0]
                })
        else:
            ndn_par = NDNutils.ffnetwork_params(
                input_dims=[1, NX2, NX2, num_lags],
                layer_sizes=[num_subs, num_hid],
                layer_types=['normal', 'normal'],
                ei_layers=[num_subs // 2, num_hid // 2],
                normalization=[1, 1],
                act_funcs=['relu', 'relu'],
                verbose=True,
                reg_list={
                    'd2t': [d2t],
                    'd2x': [d2x],
                    'l1': [L1reg0, L1reg0],
                    'glocal': [Greg0]
                })

        side_par = NDNutils.ffnetwork_params(
            network_type='side', xstim_n=None, ffnet_n=0, layer_sizes=[NC],
            layer_types=['normal'], normalization=[-1],
            act_funcs=['softplus'], verbose=True,
            reg_list={'max': [Mreg0]})

        side_par['pos_constraints'] = True  # ensures Exc and Inh mean something

        if autoencoder:
            # capture additional variability using an autoencoder on Rvalid
            auto_par = NDNutils.ffnetwork_params(
                input_dims=[1, NC, 1],
                xstim_n=[1],
                layer_sizes=[2, 1, NC],
                time_expand=[0, 15, 0],
                layer_types=['normal', 'temporal', 'normal'],
                conv_filter_widths=[None, 1, None],
                act_funcs=['relu', 'lin', 'lin'],
                normalization=[1, 1, 0],
                reg_list={'d2t': [None, 1e-1, None]})

            add_par = NDNutils.ffnetwork_params(
                xstim_n=None, ffnet_n=[1, 2], layer_sizes=[NC],
                layer_types=['add'], act_funcs=['softplus'])

            side2 = NDN.NDN([ndn_par, side_par, auto_par, add_par],
                            ffnet_out=1, noise_dist='poisson')

            # set output regularization on the latent
            side2.batch_size = adam_params['batch_size']
            side2.initialize_output_reg(network_target=2, layer_target=1,
                                        reg_vals={'d2t': 1e-1})

            input_data = [Xstim, Rvalid]

        else:
            side2 = NDN.NDN([ndn_par, side_par], ffnet_out=1,
                            noise_dist='poisson')
            input_data = Xstim

        # first fit with the initial regularization
        _ = side2.train(input_data=input_data, output_data=Rvalid,
                        train_indxs=Ui, test_indxs=Xi, silent=False,
                        learning_alg='adam', opt_params=adam_params)

        # then switch to the requested regularization for the final fit
        side2.set_regularization('glocal', Greg, layer_target=0)
        side2.set_regularization('l1', L1reg, layer_target=0)
        side2.set_regularization('max', MIreg, ffnet_target=0, layer_target=1)
        side2.set_regularization('max', MSCreg, ffnet_target=1, layer_target=0)

        side2b = side2.copy_model()

    _ = side2b.train(input_data=input_data, output_data=Rvalid,
                     train_indxs=Ui, test_indxs=Xi, silent=False,
                     learning_alg='adam', opt_params=adam_params)

    # null-adjusted log-likelihood on the test fold
    LLs2n = side2b.eval_models(input_data=input_data, output_data=Rvalid,
                               data_indxs=Xi, nulladjusted=True)
    print(np.mean(LLs2n))
    if plot:
        plt.hist(LLs2n)
        plt.xlabel('Nats/Spike')
        plt.show()

    return side2b, Xstim, Rvalid, rinds, cids, Cindx
def prep_stim_model(Stim, Robs, dims, valid=None, num_lags=10, plot=True,
                    Cindx=None, cids=None):
    """Select units with reliable STAs and build the cropped, time-embedded
    stimulus and matching responses used for model fitting.

    Parameters
    ----------
    Stim : array
        Stimulus whose columns are pixels (it is cropped as ``Stim[:, Cindx]``)
        and which is time-embedded with dims ``[num_lags, NX, NY]``.
    Robs : array, shape (NT, NC)
        Observed responses, one column per unit.
    dims : sequence
        ``(NX, NY)`` spatial size of the stimulus.
    valid : array of int, optional
        Sample indices to use; defaults to all ``NT`` samples.
    num_lags : int
        Number of time lags in the embedding.
    plot : bool
        If True, draw diagnostic figures.
    Cindx : array of int, optional
        Pixel indices of a square crop; computed from the STAs when None.
    cids : array of int, optional
        Units to keep; when None, units whose STA passes the threshold.

    Returns
    -------
    Xstim, Rvalid, dims, Cindx, cids
        Cropped time-embedded stimulus, responses of the selected units after
        ``NDNutils.shift_mat_zpad(..., -1, dim=0)``, ``(num_lags, NX2, NX2)``,
        crop indices and selected unit ids.

    Raises
    ------
    ValueError
        If no units are selected.
    """
    NX = dims[0]
    NY = dims[1]
    NT, NC = Robs.shape

    if valid is None:
        valid = np.arange(0, NT, 1)

    # create time-embedded stimulus
    Xstim, rinds = create_time_embedding_valid(Stim, [num_lags, NX, NY], valid)
    Rvalid = deepcopy(Robs[rinds, :])

    NTv = Rvalid.shape[0]
    print('%d valid samples of %d possible' % (NTv, NT))

    # STAs from mean-subtracted responses, shape (NX*NY, num_lags, NC)
    stas = Xstim.T @ (Rvalid - np.average(Rvalid, axis=0))
    stas = np.reshape(stas, [NX * NY, num_lags, NC]) / NTv

    if plot:
        plt.figure(figsize=(10, 15))
        sx, sy = U.get_subplot_dims(NC)

    # mu[cc]: fraction of STA entries whose magnitude exceeds 4x the median
    # absolute deviation (from the mean) of that unit's STA
    mu = np.zeros(NC)
    for cc in range(NC):
        if plot:
            plt.subplot(sx, sy, cc + 1)
            plt.plot(np.abs(stas[:, :, cc]).T, color=[.5, .5, .5])
        tlevel = np.median(
            np.abs(stas[:, :, cc] - np.average(stas[:, :, cc]))) * 4
        mu[cc] = np.average(np.abs(stas[:, :, cc]) > tlevel)

        if plot:
            plt.axhline(tlevel, color='k')
            plt.title(cc)

    # threshold good STAS
    thresh = 0.01
    if plot:
        plt.figure()
        plt.plot(mu, '-o')
        plt.axhline(thresh, color='k')
        plt.show()

    if cids is None:
        cids = np.where(mu > thresh)[0]  # units to analyze
        print("found %d good STAs" % len(cids))

    if len(cids) == 0:
        # nothing to crop around; fail here with a clear message instead of
        # inside the crop computation
        raise ValueError("no units selected for analysis (cids is empty)")

    if plot:
        plt.figure(figsize=(10, 15))
        for cc in cids:
            plt.subplot(sx, sy, cc + 1)
            bestlag = np.argmax(np.max(abs(stas[:, :, cc]), axis=0))
            plt.imshow(np.reshape(stas[:, bestlag, cc], (NY, NX)))
            plt.title(cc)

    # index into "good" units
    Rvalid = Rvalid[:, cids]
    NC = Rvalid.shape[1]
    stas = stas[:, :, cids]

    if Cindx is None:
        print("Getting Crop Index")
        # Crop stimulus to center around RFs: sum the squared STA of every
        # selected unit at its peak lag
        sumdensity = np.zeros([NX * NY])
        for cc in range(NC):
            bestlag = np.argmax(np.max(abs(stas[:, :, cc]), axis=0))
            sumdensity += stas[:, bestlag, cc]**2

        if plot:
            plt.figure()
            plt.imshow(np.reshape(sumdensity, [NY, NX]))
            plt.title("Sum Density STA")

        # normalize to [0, 1] and take the bounding box of pixels above 0.3
        sumdensity = (sumdensity - np.min(sumdensity)) / \
            (np.max(sumdensity) - np.min(sumdensity))
        I = np.reshape(sumdensity, [NY, NX]) > .3
        xinds = np.where(np.sum(I, axis=0) > 0)[0]
        yinds = np.where(np.sum(I, axis=1) > 0)[0]

        # square window whose side is the longer side of the bounding box
        NX2 = np.maximum(len(xinds), len(yinds))
        # Widening the shorter side can push the window past the image edge;
        # shift its origin back so the window stays inside NX x NY.
        x0 = min(np.min(xinds), max(NX - NX2, 0))
        y0 = min(np.min(yinds), max(NY - NX2, 0))

        xinds = range(x0, x0 + NX2)
        yinds = range(y0, y0 + NX2)

        Cindx = crop_indx(NX, xinds, yinds)

        if plot:
            plt.figure()
            plt.imshow(np.reshape(sumdensity[Cindx], [NX2, NX2]))
            plt.title('Cropped')
            plt.show()

    # side length of the (assumed square) crop
    NX2 = np.sqrt(len(Cindx)).astype(int)

    # make new cropped stimulus
    Xstim, rinds = create_time_embedding_valid(
        Stim[:, Cindx], [num_lags, NX2, NX2], valid)

    # index into Robs
    Rvalid = deepcopy(Robs[rinds, :])
    Rvalid = Rvalid[:, cids]
    Rvalid = NDNutils.shift_mat_zpad(Rvalid, -1, dim=0)  # get rid of first lag

    NC = Rvalid.shape[1]  # new number of units
    NT = Rvalid.shape[0]
    print('%d valid samples of %d possible' % (NT, Stim.shape[0]))
    print('%d good units' % NC)

    if plot:
        # double-check STAS work with cropped stimulus (only used for this
        # figure, so the product is skipped when not plotting)
        stas = Xstim.T @ Rvalid
        stas = np.reshape(stas, [NX2 * NX2, num_lags, NC]) / NT
        plt.figure(figsize=(10, 15))
        for cc in range(NC):
            plt.subplot(sx, sy, cc + 1)
            bestlag = np.argmax(np.max(abs(stas[:, :, cc]), axis=0))
            plt.imshow(np.reshape(stas[:, bestlag, cc], (NX2, NX2)))
            plt.title(cc)
        plt.show()

    dims = (num_lags, NX2, NX2)
    return Xstim, Rvalid, dims, Cindx, cids
#%% ypred = model(xt) cc += 1 if cc == gd.NC: cc = 0 a = ypred[:,cc].detach().cpu().numpy() r0 = np.reshape(a, (gd.opts['num_repeats'],-1)) r = np.reshape(yt[:,cc], (gd.opts['num_repeats'],-1)) r = np.average(r, axis=0) r0 = np.average(r0, axis=0) plt.plot(r) plt.plot(r0) plt.title("cell %d" %cc) U.r_squared(np.reshape(r, (-1,1)), np.reshape(r0, (-1,1))) # %% w = model.l1.weight.detach().cpu().numpy() nfilt = w.shape[0] sx,sy = U.get_subplot_dims(nfilt) plt.figure(figsize=(10,10)) for cc in range(nfilt): plt.subplot(sx,sy,cc+1) wtmp = np.reshape(w[cc,:], (gd.num_lags, gd.NX*gd.NY)) # plt.imshow(np.reshape(w[cc,:], (gd.NX*gd.NY, gd.num_lags)), aspect='auto') plt.imshow(wtmp, aspect='auto') # plt.plot(wtmp)
# get median absolute deviation per unit med = np.median(stas, axis=0) devs = np.abs(stas - med) mad = np.median(devs, axis=0) # median absolute deviation excursions = np.mean(devs > 4*mad, axis=0) thresh = 0.01 cids = np.where(excursions>thresh)[0] plt.plot(excursions, '-o') plt.plot(cids, excursions[cids], 'o') #%% plot stas NC = gd.NC sx,sy = U.get_subplot_dims(NC) sx*=2 plt.figure(figsize=(10,20)) for cc in range(NC): plt.subplot(sx,sy,cc*2+1) wtmp = stas[:,cc].reshape((gd.num_lags, -1)) tpower = np.std(wtmp, axis=1) bestlag = np.argmax(tpower) wspace = wtmp[bestlag,:].reshape( (gd.NY, gd.NX)) plt.imshow(wspace, aspect='auto') plt.axis("off") if cc in cids: plt.title(cc) plt.subplot(sx,sy,cc*2+2) if cc in cids: plt.plot(wtmp[:,np.argmax(wspace)], '-ob')
# Convert a (presumably training) `sample` -- loaded before this fragment --
# and the test set to NumPy, then set up folds and regularization constants.
# NOTE(review): stim is permuted (0, 2, 3, 1) before flattening; the meaning
# of each axis is not visible here -- confirm against the dataset.
Xstim = flat(sample['stim'].permute((0, 2, 3, 1))).detach().cpu().numpy()
Robs = sample['robs'].detach().cpu().numpy()

# test set
sample = gab_shift_test[:] # load sample
Xstim_test = flat(sample['stim'].permute((0, 2, 3, 1))).detach().cpu().numpy()
Robs_test = sample['robs'].detach().cpu().numpy()

NT = Xstim.shape[0]
Ui, Xi = NDNutils.generate_xv_folds(NT)
num_lags = gd_shift.num_lags
NC = Robs.shape[1]
adam_params = U.def_adam_params()

#%% Make model
# regularization constants ("0" suffix: presumably initial-fit values)
Greg0 = 1e-1
Greg = 1e-1
Creg0 = 1
Creg = 1e-2
Mreg0 = 1e-3
Mreg = 1e-1
L1reg0 = 1e-5
Xreg = 1e-2

num_tkern = 2
num_subs = 8

# ndn_par = NDNutils.ffnetwork_params(