def fit_shifter(
        name='20200304_kilowf',
        path='/home/jake/Data/Datasets/MitchellV1FreeViewing/stim_movies/',
        tdownsample=2,
        numlags=12,
        lengthscale=1,
        stimlist=None):
    """ fit_shifter
    Main routine. Fits a shifter network simultaneously with a convolutional
    divisive normalization layer, then pickles the fitted shifter(s).

    Inputs:
        name        <str>   session id (default '20200304_kilowf')
        path        <str>   stimulus-movie directory, passed to
                            dd.PixelDataset as ``dirname``
        tdownsample <int>   temporal downsampling factor (default 2)
        numlags     <int>   number of stimulus lags (default 12)
        lengthscale <float> shifter lengthscale; also names the checkpoint
                            directory ../../checkpoints/v1calibration_ls{lengthscale}
        stimlist    <list>  stimuli to load; None (default) means
                            ["Gabor", "Dots", "BackImage", "Grating", "FixRsvpStim"]

    Output:
        None. Returns immediately if ``<save_dir>/<name>/best_shifter.p``
        already exists. Otherwise trains (if the version directory does not
        exist yet), saves a raw-STA pdf to figDir and pickles a dict with the
        best shifter, all shifters, version numbers and validation losses.

    Raises:
        ValueError  if no trained version with a readable metrics.csv is found.
    """
    if stimlist is None:
        # fresh list per call: avoids a shared mutable default argument
        stimlist = ["Gabor", "Dots", "BackImage", "Grating", "FixRsvpStim"]

    num_lags = numlags
    sessid = name  # '20200304_kilowf'
    save_dir = '../../checkpoints/v1calibration_ls{}'.format(lengthscale)
    figDir = "/home/jake/Data/Repos/V1FreeViewingCode/Figures/2021_pytorchmodeling"

    outfile = Path(save_dir) / sessid / 'best_shifter.p'
    if outfile.exists():
        print("fit_shifter: shifter model already fit for this session")
        return

    #%% LOAD ALL DATASETS

    # build tent basis for saccades: (n x num_basis) triangular bumps whose
    # centers are spaced n / num_basis samples apart
    n = 40
    num_basis = 15
    B = np.maximum(
        1 - np.abs(
            np.expand_dims(np.asarray(np.arange(0, n)), axis=1) -
            np.arange(0, n, n / num_basis)) / n * num_basis, 0)

    t_downsample = tdownsample
    gd = dd.PixelDataset(sessid,
                         stims=stimlist,
                         stimset="Train",
                         num_lags=num_lags,
                         downsample_t=t_downsample,
                         downsample_s=1,
                         valid_eye_rad=5.2,
                         dirname=path,
                         include_frametime={
                             'num_basis': 40,
                             'full_experiment': False
                         },
                         include_saccades=[{
                             'name': 'sacon',
                             'basis': B,
                             'offset': -20
                         }, {
                             'name': 'sacoff',
                             'basis': B,
                             'offset': 0
                         }],
                         include_eyepos=True,
                         preload=True)

    #%% compute STAS
    # STAs via einstein summation through pytorch, using only the frames whose
    # stim index is 0 (the first stimulus loaded).
    index = np.where(gd.stim_indices == 0)[0]
    sample = gd[index]  # load sample
    stas = torch.einsum('nlwh,nc->lwhc', sample['stim'],
                        sample['robs'] - sample['robs'].mean(dim=0))
    sta = stas.detach().cpu().numpy()

    # plot space/time STAs and get RF centers
    NC = sta.shape[3]
    mu = np.zeros((NC, 2))
    # grid with an even number of columns so each cell's (space, time) panel
    # pair shares a row; plt.subplot requires integer grid sizes
    sx = np.ceil(np.sqrt(NC * 2))
    sy = np.round(np.sqrt(NC * 2))
    mod2 = sy % 2
    sy += mod2
    sx -= mod2
    sx, sy = int(sx), int(sy)

    plt.figure(figsize=(sx * 2, sy * 2))
    for cc in range(NC):
        w = sta[:, :, :, cc]

        # RF center = center of mass of the spatial std map, sharpened by a
        # 10th power so the center of mass follows the peak
        wt = np.std(w, axis=0)
        wt /= np.max(np.abs(wt))  # normalize for numerical stability
        wt = wt**10
        wt /= np.sum(wt)
        sz = wt.shape
        xx, yy = np.meshgrid(np.linspace(-1, 1, sz[1]),
                             np.linspace(1, -1, sz[0]))
        # center of mass, clipped to [-.5, .5]
        mu[cc, 0] = np.minimum(np.maximum(np.sum(xx * wt), -.5), .5)
        mu[cc, 1] = np.minimum(np.maximum(np.sum(yy * wt), -.5), .5)

        w = (w - np.mean(w)) / np.std(w)
        bestlag = np.argmax(np.std(w.reshape((gd.num_lags, -1)), axis=1))

        # spatial slice at the best lag
        plt.subplot(sx, sy, cc * 2 + 1)
        v = np.max(np.abs(w))
        plt.imshow(w[bestlag, :, :],
                   aspect='auto',
                   interpolation=None,
                   vmin=-v,
                   vmax=v,
                   cmap="coolwarm",
                   extent=(-1, 1, -1, 1))
        plt.plot(mu[cc, 0], mu[cc, 1], '.b')
        plt.title(cc)

        # temporal profile at the max (blue) and min (red) pixel
        plt.subplot(sx, sy, cc * 2 + 2)
        i, j = np.where(w[bestlag, :, :] == np.max(w[bestlag, :, :]))
        plt.plot(w[:, i[0], j[0]], '-ob')
        i, j = np.where(w[bestlag, :, :] == np.min(w[bestlag, :, :]))
        plt.plot(w[:, i[0], j[0]], '-or')

    plt.savefig(figDir + "/rawstas_" + gd.id + ".pdf", bbox_inches='tight')
    plt.close()

    #%% Refit DivNorm Model on all datasets
    # Fit single layer DivNorm model with modulation
    for version in range(1):  # range of version numbers

        #% Model: convolutional model
        input_channels = gd.num_lags
        hidden_channels = 16
        input_kern = 19
        hidden_kern = 5
        core = cores.Stacked2dDivNorm(
            input_channels,
            hidden_channels,
            input_kern,
            hidden_kern,
            layers=1,
            gamma_hidden=1e-6,  # group sparsity
            gamma_input=1,
            gamma_center=0,
            skip=0,
            final_nonlinearity=True,
            bias=False,
            pad_input=True,
            hidden_padding=hidden_kern // 2,
            group_norm=True,
            num_groups=4,
            weight_norm=True,
            hidden_dilation=1,
            input_regularizer="RegMats",
            input_reg_types=["d2x", "center", "d2t"],
            input_reg_amt=[.000005, .01, 0.00001],
            hidden_reg_types=["d2x", "center"],
            hidden_reg_amt=[.000005, .01],
            stack=None,
            use_avg_reg=True)

        # initialize input layer to be centered
        regw = regularizers.gaussian2d(input_kern, sigma=input_kern // 4)
        core.features[0].conv.weight.data = torch.einsum(
            'ijkm,km->ijkm', core.features[0].conv.weight.data,
            torch.tensor(regw))

        # Readout
        in_shape = [core.outchannels, gd.NY, gd.NX]
        bias = True
        readout = readouts.Point2DGaussian(in_shape,
                                           gd.NC,
                                           bias,
                                           init_mu_range=0.1,
                                           init_sigma=1,
                                           batch_sample=True,
                                           gamma_l1=0,
                                           gamma_l2=0.00001,
                                           align_corners=True,
                                           gauss_type='uncorrelated',
                                           constrain_positive=False,
                                           shifter={
                                               'hidden_features': 20,
                                               'hidden_layers': 1,
                                               'final_tanh': False,
                                               'activation': "softplus",
                                               'lengthscale': lengthscale
                                           })

        modifiers = {
            'stimlist': ['frametime', 'sacoff'],
            'gain': [sample['frametime'].shape[1], sample['sacoff'].shape[1]],
            'offset': [sample['frametime'].shape[1], sample['sacoff'].shape[1]],
            'stage': "readout",
            'outdims': gd.NC
        }

        # combine core and readout into model
        model = encoders.EncoderMod(
            core,
            readout,
            modifiers=modifiers,
            gamma_mod=0,
            weight_decay=.001,
            optimizer='AdamW',
            learning_rate=.01,  # high initial learning rate because we decay on plateau
            betas=[.9, .999],
            amsgrad=False)

        # initialize readout based on spike rate and STA centers
        model.readout.bias.data = sample['robs'].mean(dim=0)  # initialize readout bias helps
        model.readout._mu.data[0, :, 0, :] = torch.tensor(mu.astype('float32'))  # initialize mus

        #% Train
        trainer, train_dl, valid_dl = ut.get_trainer(gd,
                                                     version=version,
                                                     save_dir=save_dir,
                                                     name=gd.id,
                                                     auto_lr=False,
                                                     batchsize=1000,
                                                     num_workers=64,
                                                     earlystopping=True)

        trainpath = Path(save_dir) / gd.id / "version_{}".format(version)
        if not trainpath.exists():
            trainer.fit(model, train_dl, valid_dl)

    # Collect the best validation loss of every trained version on disk.
    pth = Path(save_dir) / gd.id
    ind_list = []
    valloss_list = []
    for v in pth.rglob('version*'):
        try:
            version_num = int(v.name.split('_')[1])
            df = pd.read_csv(str(v / "metrics.csv"))
            best_loss = np.nanmin(df.val_loss.to_numpy())
        except (OSError, ValueError, IndexError, AttributeError) as err:
            # version without usable metrics (e.g. an interrupted run): skip
            # it, and append to neither list so ind and valloss stay aligned
            print("fit_shifter: skipping {} ({})".format(v, err))
            continue
        ind_list.append(version_num)
        valloss_list.append(best_loss)

    if not ind_list:
        raise ValueError(
            "fit_shifter: no version with a readable metrics.csv in {}".format(pth))

    ind = np.asarray(ind_list, dtype=float)
    valloss = np.asarray(valloss_list, dtype=float)
    sortind = np.argsort(ind)
    ind = ind[sortind]
    valloss = valloss[sortind]

    # load all shifters (same order as ind / valloss)
    shifters = nn.ModuleList()
    for vnum in ind:
        ver = "version_" + str(int(vnum))
        chkpath = pth / ver / 'checkpoints'
        best_epoch = ut.find_best_epoch(chkpath)
        model2 = encoders.EncoderMod.load_from_checkpoint(
            str(chkpath / best_epoch))
        shifters.append(model2.readout.shifter.cpu())

    # best shifter = version with the lowest validation loss
    shifter = shifters[int(np.argmin(valloss))]

    print("saving file")
    outdict = {
        'cids': gd.cids,
        'shifter': shifter,
        'shifters': shifters,
        'vernum': ind,
        'vallos': valloss,
        'numlags': num_lags,
        'tdownsample': tdownsample,
        'lengthscale': lengthscale
    }
    outfile.parent.mkdir(parents=True, exist_ok=True)
    with open(str(outfile), "wb") as f:
        pickle.dump(outdict, f)
# Plot the null-adjusted log-likelihood of model2 on the test set
# (get_null_adjusted_ll is defined elsewhere; presumably one value per
# neuron -- confirm). The black line marks 0.
sample = gd_test[:]
l2 = get_null_adjusted_ll(model2, sample)
plt.figure()
plt.plot(l2, '-o')
plt.axhline(0, color='k')

#%% Reload dataset in restricted position
# Gabor-only training set with valid_eye_rad=2 around valid_eye_ctr=(0, 0);
# presumably keeps only gaze samples near the screen center -- confirm in
# dd.PixelDataset.
import V1FreeViewingCode.models.datasets as dd
import importlib
importlib.reload(dd)  # re-import so edits to the datasets module take effect
gd_restricted = dd.PixelDataset(sessid, stims=["Gabor"],
    stimset="Train", num_lags=num_lags,
    downsample_t=t_downsample,
    downsample_s=1,
    valid_eye_rad=2,
    valid_eye_ctr=(0.0,0.0),
    include_eyepos=True,
    preload=True)

#%% retrain
# Freeze core layer0 (conv weight and norm weight/bias); keep layer1 conv and
# the readout features trainable.
# don't train first layer
model2.core.features.layer0.conv.weight.requires_grad = False
model2.core.features.layer0.norm.weight.requires_grad = False
model2.core.features.layer0.norm.bias.requires_grad = False
# train pooling layer
model2.core.features.layer1.conv.weight.requires_grad = True
# do train readout
# NOTE(review): only effective if readout.features is a single Parameter -- confirm
model2.readout.features.requires_grad = True
valloss = tmp['vallos'] num_lags = tmp['numlags'] t_downsample = tmp['tdownsample'] else: print("foveal_stas: need to run fit_shifter for %s" %sessid) #%% LOAD DATASETS with shifter import V1FreeViewingCode.models.datasets as dd import importlib importlib.reload(dd) gd = dd.PixelDataset(sessid, stims=["Dots", "Gabor"], #, "BackImage", "Grating", "FixRsvpStim"], stimset="Train", num_lags=num_lags, downsample_t=t_downsample, downsample_s=1, valid_eye_rad=5.2, shifter=shifter, include_eyepos=True, preload=True) #%% get sample sample = gd[:] im = sample['stim'].detach().clone() #%% new STAs on shifted stimulus crop = [20, 50, 0, 30] stas = torch.einsum('nlwh,nc->lwhc', im[:,:,crop[2]:crop[3], crop[0]:crop[1]], sample['robs']-sample['robs'].mean(dim=0)) sta = stas.detach().cpu().numpy() #%%
def make_figures(
        name='20200304_kilowf',
        path='/home/jake/Data/Datasets/MitchellV1FreeViewing/stim_movies/',
        tdownsample=2,
        numlags=12,
        lengthscale=1,
        nspace=100):
    """ make_figures
    Make STA figures and dump analyses for Figure 5 of Yates et al., 2021

    Inputs:
        name        <str>   session id
        path        <str>   stimulus-movie directory, passed to
                            dd.PixelDataset as ``dirname`` (as in fit_shifter)
        tdownsample <int>   ignored: the value saved by fit_shifter is used
        numlags     <int>   ignored: the value saved by fit_shifter is used
        lengthscale <float> selects ../../checkpoints/v1calibration_ls{lengthscale}
        nspace      <int>   number of spatial positions in shifter plots (default = 100)

    Output:
        None
        Saves a mat file with all analyses and dumps a pdf for each cell.
        Returns early (after printing a message) if fit_shifter has not been
        run for this session.

    Raises:
        ValueError  if neither "Dots" nor "Gabor" ended up in the dataset.
    """
    print("Running Make Figures")
    import scipy.io as sio  # for saving matlab files

    figDir = "/home/jake/Data/Repos/V1FreeViewingCode/Figures/2021_pytorchmodeling"
    save_dir = '../../checkpoints/v1calibration_ls{}'.format(lengthscale)
    outfile = Path(save_dir) / name / 'best_shifter.p'

    if not outfile.exists():
        print("v1_tracker_calibration: need to run fit_shifter for %s" % name)
        return

    print("fit_shifter was already run. Loading [%s]" % name)
    # NOTE: unpickling can execute code; only load files written by fit_shifter
    with open(str(outfile), "rb") as f:
        tmp = pickle.load(f)
    cids = tmp['cids']
    shifters = tmp['shifters']
    valloss = tmp['vallos']
    # the settings used during fitting override numlags / tdownsample
    num_lags = tmp['numlags']
    t_downsample = tmp['tdownsample']

    # tent basis for saccades (same as fit_shifter)
    n = 40
    num_basis = 15
    B = np.maximum(
        1 - np.abs(
            np.expand_dims(np.asarray(np.arange(0, n)), axis=1) -
            np.arange(0, n, n / num_basis)) / n * num_basis, 0)

    gd = dd.PixelDataset(
        name,
        stims=["Dots", "Gabor"],  #, "BackImage", "Grating", "FixRsvpStim"],
        stimset="Train",
        num_lags=num_lags,
        downsample_t=t_downsample,
        downsample_s=1,
        valid_eye_rad=5.2,
        dirname=path,
        include_frametime={
            'num_basis': 40,
            'full_experiment': False
        },
        include_saccades=[{
            'name': 'sacon',
            'basis': B,
            'offset': -20
        }, {
            'name': 'sacoff',
            'basis': B,
            'offset': 0
        }],
        include_eyepos=True,
        preload=True)

    # --- Plot all shifters
    # evaluate every shifter on an nspace x nspace grid of eye positions
    xax = np.linspace(-gd.valid_eye_rad, gd.valid_eye_rad, nspace)
    xx, yy = np.meshgrid(xax, xax)
    xgrid = torch.tensor(xx.astype('float32').reshape((-1, 1)))
    ygrid = torch.tensor(yy.astype('float32').reshape((-1, 1)))
    inputs = torch.cat((xgrid, ygrid), dim=1)

    nshifters = len(shifters)
    print("Start Plot")
    plt.figure(figsize=(7, nshifters * 2))
    shiftX = []
    shiftY = []
    for i, shifter_i in enumerate(shifters):
        y2 = shifter_i(inputs).detach().cpu().numpy()
        y2 /= gd.valid_eye_rad / 60  # convert to arcmin
        vmin = np.min(y2)
        vmax = np.max(y2)
        grid_x = y2[:, 0].reshape((nspace, nspace))
        grid_y = y2[:, 1].reshape((nspace, nspace))

        plt.subplot(nshifters, 2, i * 2 + 1)
        im = plt.contourf(xax, xax, grid_x, vmin=vmin, vmax=vmax)
        if i == 0:
            plt.title("Horizontal")
        plt.colorbar(im)

        plt.subplot(nshifters, 2, i * 2 + 2)
        im = plt.contourf(xax, xax, grid_y, vmin=vmin, vmax=vmax)
        if i == 0:
            plt.title("Vertical")
        plt.colorbar(im)

        shiftX.append(grid_x)
        shiftY.append(grid_y)

    print("Save fig")
    plt.savefig(figDir + "/shifters_" + gd.id + ".pdf", bbox_inches='tight')
    print("Success")

    # calculate mean and standard deviation across shifters
    Xarray = np.asarray(shiftX)
    Yarray = np.asarray(shiftY)
    mux = Xarray.mean(axis=0)
    sdx = Xarray.std(axis=0)
    muy = Yarray.mean(axis=0)
    sdy = Yarray.std(axis=0)

    print("Select Stim")
    # select stimulus: prefer Dots, fall back to Gabor
    stim_names = list(gd.stims)
    if 'Dots' in stim_names:
        stimuse = stim_names.index('Dots')
    elif 'Gabor' in stim_names:
        stimuse = stim_names.index('Gabor')
    else:
        raise ValueError(
            "make_figures: neither 'Dots' nor 'Gabor' was loaded for %s" % name)

    index = np.where(gd.stim_indices == stimuse)[0]
    sample = gd[index]  # load sample

    # use best model (shifters and valloss share fit_shifter's version order)
    bestver = np.argmin(valloss)
    shift = shifters[bestver](sample['eyepos']).detach()
    # NOTE(review): unlike the per-shifter plots above, this grid is saved
    # without the arcmin conversion -- confirm that is intended.
    y2 = shifters[bestver](inputs).detach().cpu().numpy()

    print("Shift Stim")
    stim_orig = sample['stim']  # original
    stim_shifted = shift_stim(stim_orig, shift, gd)  # shifted

    print("Get STAs")
    # STAs before and after shifting the stimulus
    robs_centered = sample['robs'] - sample['robs'].mean(dim=0)
    sta = torch.einsum('nlwh,nc->lwhc', stim_orig, robs_centered).detach().cpu().numpy()
    sta2 = torch.einsum('nlwh,nc->lwhc', stim_shifted, robs_centered).detach().cpu().numpy()

    print("Save Mat File")
    # Save output for matlab
    fname = figDir + "/rfs_" + gd.id + ".mat"
    mdict = {
        'cids': cids,
        'xspace': xx,
        'yspace': yy,
        'shiftx': y2[:, 0].reshape((nspace, nspace)),
        'shifty': y2[:, 1].reshape((nspace, nspace)),
        'stas_pre': sta,
        'stas_post': sta2,
        'valloss': valloss,
        'mushiftx': mux,
        'mushifty': muy,
        'sdshiftx': sdx,
        'sdshifty': sdy
    }
    sio.savemat(fname, mdict)

    # compute dimensions (axis labels below are in arcmin)
    extent = np.round(gd.rect / gd.ppd * 60)
    extent = np.asarray([extent[i] for i in [0, 2, 3, 1]])

    print("Plot STAs")
    # Plot STAs Before and After, one pdf per cell
    NC = sta.shape[3]
    for cc in range(NC):
        plt.figure(figsize=(8, 2))
        w = sta[:, :, :, cc]
        w2 = sta2[:, :, :, cc]

        # best lag is taken from the shifted STA and used for both panels
        bestlag = np.argmax(np.std(w2.reshape((gd.num_lags, -1)), axis=1))
        w = (w - np.mean(w[bestlag, :, :])) / np.std(w)  # before
        w2 = (w2 - np.mean(w2[bestlag, :, :])) / np.std(w2)  # after

        plt.subplot(1, 4, 3)  # After Space
        # shared color limit (from "after") for both spatial panels
        v = np.max(np.abs(w2))
        plt.imshow(w2[bestlag, :, :],
                   aspect='auto',
                   interpolation=None,
                   vmin=-v,
                   vmax=v,
                   cmap="coolwarm",
                   extent=extent)
        plt.title("After")
        plt.xlabel("arcmin")
        plt.ylabel("arcmin")

        plt.subplot(1, 4, 4)  # After Time
        i, j = np.where(w2[bestlag, :, :] == np.max(w2[bestlag, :, :]))
        plt.plot(w2[:, i[0], j[0]], '-ob')
        i, j = np.where(w2[bestlag, :, :] == np.min(w2[bestlag, :, :]))
        plt.plot(w2[:, i[0], j[0]], '-or')
        yd = plt.ylim()
        plt.xlabel("Lag (frame=8ms)")

        plt.subplot(1, 4, 1)  # Before Space
        plt.imshow(w[bestlag, :, :],
                   aspect='auto',
                   interpolation=None,
                   vmin=-v,
                   vmax=v,
                   cmap="coolwarm",
                   extent=extent)
        plt.title("Before")
        plt.xlabel("arcmin")
        plt.ylabel("arcmin")

        plt.subplot(1, 4, 2)  # Before Time
        i, j = np.where(w[bestlag, :, :] == np.max(w[bestlag, :, :]))
        plt.plot(w[:, i[0], j[0]], '-ob')
        i, j = np.where(w[bestlag, :, :] == np.min(w[bestlag, :, :]))
        plt.plot(w[:, i[0], j[0]], '-or')
        plt.axvline(bestlag, color='k', linestyle='--')
        plt.ylim(yd)  # same y-range as the "after" time panel
        plt.xlabel("Lag (frame=8ms)")

        plt.savefig(figDir + "/sta_shift" + gd.id + "_" + str(cc) + ".pdf",
                    bbox_inches='tight')
        plt.close('all')

    print("Save STA figures")
    # save all STAs as one figure
    plot_sta_fig(sta, gd)  # Before
    plt.savefig(figDir + "/rawstas_" + gd.id + ".pdf", bbox_inches='tight')
    plot_sta_fig(sta2, gd)  # After
    plt.savefig(figDir + "/shiftstas_" + gd.id + ".pdf", bbox_inches='tight')
    plt.close('all')
np.expand_dims(np.asarray(np.arange(0, n)), axis=1) - np.arange(0, n, n / num_basis)) / n * num_basis, 0) t_downsample = tdownsample gd = dd.PixelDataset(sessid, stims=stimlist, stimset="Train", num_lags=num_lags, downsample_t=t_downsample, downsample_s=1, valid_eye_rad=5.2, dirname=path, include_frametime={ 'num_basis': 40, 'full_experiment': False }, include_saccades=[{ 'name': 'sacon', 'basis': B, 'offset': -20 }, { 'name': 'sacoff', 'basis': B, 'offset': 0 }], include_eyepos=True, preload=True) sample = gd[:] #%% compute STAS
importlib.reload(dd)  # re-import so edits to the datasets module take effect

cropidx = cropidxs[sessid]  # per-session crop indices from the cropidxs lookup

# tent basis for saccades: (n x num_basis) triangular bumps spaced n/num_basis apart
n = 40
num_basis = 15
B = np.maximum(1 - np.abs(np.expand_dims(np.asarray(np.arange(0,n)), axis=1) - np.arange(0,n,n/num_basis))/n*num_basis, 0)

# Load Training / Evaluation
# Stimulus is shifted with shifters[np.argmin(valloss)], i.e. the lowest
# validation-loss shifter (shifters/valloss presumably loaded from
# fit_shifter's best_shifter.p -- confirm), cropped, and passed through a
# Gaussian PSF ('optics').
gd = dd.PixelDataset(sessid, stims=["Dots", "Gabor", "BackImage", "Grating", "FixRsvpStim"],
    stimset="Train", num_lags=num_lags,
    downsample_t=t_downsample,
    downsample_s=1,
    valid_eye_rad=5.2,
    shifter=shifters[np.argmin(valloss)],
    cropidx=cropidx,
    include_frametime={'num_basis': 40, 'full_experiment': False},
    include_saccades=[{'name':'sacon', 'basis':B, 'offset':-20}, {'name':'sacoff', 'basis':B, 'offset':0}],
    include_eyepos=True,
    optics={'type': 'gausspsf', 'sigma': (0.7, 0.7, 0.0)},
    preload=True)

#%% Load test set
# print("Loading Test set")
# gd_test = dd.PixelDataset(sessid, stims=["Dots", "Gabor", "BackImage"], #, "Grating", "FixRsvpStim"],
#     stimset="Test", num_lags=num_lags,
#     downsample_t=t_downsample,
#     downsample_s=1,
#     valid_eye_rad=5.2,
#     cropidx=cropidx,
#     shifter=shifter,
#%% LOAD ALL DATASETS # build tent basis for saccades n = 40 num_basis = 15 B = np.maximum(1 - np.abs(np.expand_dims(np.asarray(np.arange(0,n)), axis=1) - np.arange(0,n,n/num_basis))/n*num_basis, 0) import V1FreeViewingCode.models.datasets as dd import importlib importlib.reload(dd) gd = dd.PixelDataset(sessid, stims=["Dots", "Gabor"], #, "BackImage", "Grating", "FixRsvpStim"], stimset="Train", num_lags=num_lags, downsample_t=t_downsample, downsample_s=1, valid_eye_rad=5.2, include_frametime={'num_basis': 40, 'full_experiment': False}, include_saccades=[{'name':'sacon', 'basis':B, 'offset':-20}, {'name':'sacoff', 'basis':B, 'offset':0}], include_eyepos=True, preload=True) sample = gd[:] # from V1FreeViewingCode.Analysis.manuscript_freeviewingmethods.v1_tracker_calibration import make_figures # %% Plot all shifters # build inputs for shifter plotting nspace = 100
# Plot the null-adjusted log-likelihood of model2 (get_null_adjusted_ll is
# defined elsewhere; presumably one value per neuron -- confirm).
l2 = get_null_adjusted_ll(model2, sample)
plt.figure()
plt.plot(l2, '-o')
plt.axhline(0, color='k')

#%% Reload dataset in restricted position
# Gabor-only training set, cropped with cropidx=[(15,50),(20,50)] (axis order
# of the crop not shown here -- confirm) and shifted by model2's shifter.
import V1FreeViewingCode.models.datasets as dd
import importlib
importlib.reload(dd)  # re-import so edits to the datasets module take effect
gd_shift = dd.PixelDataset(sessid, stims=["Gabor"],
    stimset="Train", num_lags=num_lags,
    downsample_t=t_downsample,
    downsample_s=1,
    valid_eye_rad=5.2,
    valid_eye_ctr=(0.0,0.0),
    include_eyepos=True,
    cropidx=[(15,50),(20,50)],
    shifter=model2.readout.shifter,
    preload=True)

# %% reload sample and compute STAs
# STA = sum over samples n of stim x mean-subtracted response, divided by each
# output column's total response (robs presumably spike counts -- confirm).
sample = gd_shift[:] # load sample
stas = torch.einsum('nlwh,nc->lwhc', sample['stim'], (sample['robs']-sample['robs'].mean(dim=0))/sample['robs'].sum(dim=0))
sta = stas.detach().cpu().numpy()

#%% plot STAs / get RF centers
""" Plot space/time STAs """
NC = sta.shape[3]  # number of output columns of robs (cells)
import V1FreeViewingCode.models.datasets as dd import importlib importlib.reload(dd) t_downsample = 1 num_lags = 20 gd_shift = dd.PixelDataset( sessid, stims=["Gabor", "Grating", "BackImage"], stimset="Train", num_lags=num_lags, downsample_t=t_downsample, downsample_s=1, valid_eye_rad=5, valid_eye_ctr=(0.0, 0.0), include_eyepos=True, cropidx=win, cids=cids, shifter=shifter, #model2.readout.shifter, preload=True, temporal=True) #%% test set gab_shift_test = dd.PixelDataset(sessid, stims=["Gabor"], stimset="Test", num_lags=num_lags, downsample_t=t_downsample, downsample_s=1,