def plot_tfilters(ndnmod, kts=None, ffnet=0, to_plot=True):
    """Plot temporal-basis filters. The first argument can be either an NDN model
    or a weight array taken from the relevant layer. The temporal kernels must be
    passed in through the kts argument."""
    assert kts is not None, 'Must include tkerns.'
    ntk = kts.shape[1]
    if type(ndnmod) is np.ndarray:
        # then passing in filters directly, and need more information
        ws = deepcopy(ndnmod)
    else:
        ws = deepcopy(ndnmod.networks[ffnet].layers[0].weights)
    if len(ws.shape) > 2:
        nx, ntk2, numcells = ws.shape
        ws2 = ws
    else:
        nx = ws.shape[0] // ntk
        numcells = ws.shape[1]
        ws2 = np.reshape(ws, [nx, ntk, numcells])
    ks = np.expand_dims(kts, axis=0) @ ws2
    if to_plot:
        DU.plot_filters(filters=ks)
    else:
        return ks
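# Hedged usage sketch: the first argument may be a fitted NDN model or a raw
# weight array of shape [nx * num_tkerns, num_cells]; `tkerns` here is an assumed
# temporal-kernel matrix of shape [num_lags, num_tkerns] defined in the calling
# script. With to_plot=False the projected filters are returned instead of plotted.
# ks = plot_tfilters(ndn_mod, kts=tkerns, ffnet=0, to_plot=False)
# ks.shape  # -> (nx, num_lags, num_cells)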
def compute_binocular_tfilters(binoc_mod, kts=None, to_plot=True):
    assert kts is not None, 'Must include tkerns.'
    BFs = compute_binocular_filters(binoc_mod, to_plot=False)
    Bks = np.transpose(np.tensordot(kts, BFs, axes=[1, 0]), (1, 0, 2))
    if to_plot:
        DU.plot_filters(filters=Bks)
    else:
        return Bks
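# Hedged usage sketch (assumes a fitted binocular model `binoc_mod` and temporal
# kernels `tkerns` of shape [num_lags, num_tkerns] from the calling script):
# Bks = compute_binocular_tfilters(binoc_mod, kts=tkerns, to_plot=False)
# Bks.shape  # -> (2 * num_space, num_lags, NF)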
def compute_binocular_filters(binoc_mod, to_plot=True):
    # Find binocular layer
    blayer, bnet = None, None
    for mm in range(len(binoc_mod.networks)):
        for nn in range(len(binoc_mod.networks[mm].layers)):
            if binoc_mod.network_list[mm]['layer_types'][nn] == 'biconv':
                if nn < len(binoc_mod.networks[mm].layers) - 1:
                    bnet, blayer = mm, nn + 1
                elif mm < len(binoc_mod.networks) - 1:
                    bnet, blayer = mm + 1, 0  # split in hierarchical network
    assert blayer is not None, 'biconv layer not found'

    NF = binoc_mod.networks[0].layers[blayer].output_dims[0]
    Nin = binoc_mod.networks[0].layers[blayer].input_dims[0]
    NX = binoc_mod.networks[0].layers[blayer].filter_dims[1]
    ks1 = DU.compute_spatiotemporal_filters(binoc_mod)
    ws = np.reshape(binoc_mod.networks[0].layers[blayer].weights, [NX, Nin, NF])
    num_lags = binoc_mod.networks[0].layers[0].input_dims[0]
    if binoc_mod.networks[0].layers[0].filter_dims[1] > 1:  # then not temporal layer
        filter_dims = [num_lags, binoc_mod.networks[0].layers[0].filter_dims[1]]
    else:
        filter_dims = [num_lags, binoc_mod.networks[0].layers[1].filter_dims[1]]
    nfd = [filter_dims[0], filter_dims[1] + NX]
    # print(filter_dims, nfd)
    Bfilts = np.zeros(nfd + [NF, 2])
    for nn in range(NX):
        Bfilts[:, np.add(range(filter_dims[1]), nn), :, 0] += np.reshape(
            np.matmul(ks1, ws[nn, range(Nin // 2), :]),
            [filter_dims[1], filter_dims[0], NF])
        Bfilts[:, np.add(range(filter_dims[1]), nn), :, 1] += np.reshape(
            np.matmul(ks1, ws[nn, range(Nin // 2, Nin), :]),
            [filter_dims[1], filter_dims[0], NF])
    bifilts = np.concatenate((Bfilts[:, :, :, 0], Bfilts[:, :, :, 1]), axis=1)
    if to_plot:
        DU.plot_filters(filters=bifilts, flipxy=True)
    else:
        return bifilts
_ = baseglm.train(input_data=[Time, Xcon], output_data=Robs,
                  train_indxs=Ui, test_indxs=Xi,
                  learning_alg=optimizer, opt_params=opt_params,
                  use_dropout=False, fit_variables=v2f0)

#%% Find best Regularization
reg_results = DU.unit_reg_test(baseglm, input_data=[Time, Xcon], output_data=Robs,
                               train_indxs=Ui, test_indxs=Xi,
                               reg_type='d2t',
                               reg_vals=[1e-6, 1e-4, 1e-3, 1e-2, 0.1, 1],
                               layer_targets=[0], ffnet_targets=[1],
                               learning_alg='lbfgs', opt_params=lbfgs_params)

baseglm = DU.unit_assign_reg(baseglm, reg_results)

#%%
f = plt.plot(baseglm.networks[1].layers[0].weights)

plt.figure(figsize=(10, 5))
f = plt.plot(Time @ baseglm.networks[0].layers[0].weights)

#%% Stim GLM
    Ic = frame[y0, x1, :]
    Id = frame[y1, x1, :]

    # bilinear interpolation: weighted sum of the four neighboring pixels,
    # applied channel by channel
    out = Ia
    for i in range(C):
        out[:, :, i] = wa * Ia[:, :, i] + wb * Ib[:, :, i] + wc * Ic[:, :, i] + wd * Id[:, :, i]

    return out


# %%
I = DU.gabor_sized(30, 90)
plt.imshow(I)
plt.title("Default image")

outsize = np.array((20, 20))
translation = np.array((0, 0))
theta = 10

# # %% cropping an image
# for i in np.arange(-10,10,2):
#     plt.figure()
#     Ic = roi_crop(I, np.array( (50,50)), np.array( (30,30)), theta=i)
#     plt.imshow(Ic)

# %% cropping / rotating a vector (flattened image)
# glm0.networks[0].layers[0].biases = np.mean(Rvalid, axis=0).astype('float32')
v2f0 = glm0.fit_variables(fit_biases=False)
v2f0[-1][-1]['biases'] = True

# train initial model
_ = glm0.train(input_data=[Xstim], output_data=Rvalid,
               train_indxs=Ui, test_indxs=Xi,
               learning_alg='lbfgs', opt_params=lbfgs_params,
               fit_variables=v2f0)

#%% plot filters
DU.plot_3dfilters(glm0)

LLx0 = glm0.eval_models(input_data=[Xstim], output_data=Rvalid,
                        data_indxs=Xi, nulladjusted=True)
plt.plot(LLx0, '-o')
plt.axhline(0)

# %% get crop indices
# Cxinds = ne.crop_indx(NX, range(1,30), range(1,30))
Cxinds = ne.crop_indx(NX, range(9, 24), range(9, 24))
# Cxinds = ne.crop_indx(NX, range(5,20), range(5,20))
# Cxinds = ne.crop_indx(NX, range(20,44), range(20,44))
NX2 = np.sqrt(len(Cxinds)).astype(int)
def compute_binocular_filters(binoc_mod, ffnet_n=0, to_plot=True, num_space=36):
    """Compute filters from a standard binocular model. Defaults to the first ffnet
    and num_space=36; set num_space=None to use the minimum size allowed by the
    convolutional constraints."""

    # Find binocular layer
    blayer, bnet = None, None
    for mm in range(len(binoc_mod.networks)):
        for nn in range(len(binoc_mod.networks[mm].layers)):
            if binoc_mod.network_list[mm]['layer_types'][nn] == 'biconv':
                if nn < len(binoc_mod.networks[mm].layers) - 1:
                    bnet, blayer = mm, nn + 1
                elif mm < len(binoc_mod.networks) - 1:
                    bnet, blayer = mm + 1, 0  # split in hierarchical network
    assert blayer is not None, 'biconv layer not found'

    NF = binoc_mod.networks[ffnet_n].layers[blayer].output_dims[0]
    Nin = binoc_mod.networks[ffnet_n].layers[blayer].input_dims[0]
    NX = binoc_mod.networks[ffnet_n].layers[blayer].filter_dims[1]
    ks1 = DU.compute_spatiotemporal_filters(binoc_mod)
    ws = np.reshape(binoc_mod.networks[ffnet_n].layers[blayer].weights, [NX, Nin, NF])
    num_lags = binoc_mod.networks[ffnet_n].layers[0].input_dims[0]
    if binoc_mod.networks[ffnet_n].layers[0].filter_dims[1] > 1:  # then not temporal layer
        filter_dims = [num_lags, binoc_mod.networks[ffnet_n].layers[0].filter_dims[1]]
    else:
        filter_dims = [num_lags, binoc_mod.networks[ffnet_n].layers[1].filter_dims[1]]
    num_cspace = filter_dims[1] + NX
    nfd = [filter_dims[0], num_cspace]
    # print(filter_dims, nfd)
    Bfilts = np.zeros(nfd + [NF, 2])
    for nn in range(NX):
        Bfilts[:, np.add(range(filter_dims[1]), nn), :, 0] += np.reshape(
            np.matmul(ks1, ws[nn, range(Nin // 2), :]),
            [filter_dims[1], filter_dims[0], NF])
        Bfilts[:, np.add(range(filter_dims[1]), nn), :, 1] += np.reshape(
            np.matmul(ks1, ws[nn, range(Nin // 2, Nin), :]),
            [filter_dims[1], filter_dims[0], NF])

    # Cast into desired num_space
    if num_space is None:
        num_space = num_cspace
    if num_space == num_cspace:
        BfiltsX = Bfilts
    elif num_space > num_cspace:  # pad with zeros
        BfiltsX = np.zeros([filter_dims[0], num_space, NF, 2])
        padding = (num_space - num_cspace) // 2
        BfiltsX[:, padding + np.arange(num_cspace), :, :] = Bfilts
    else:  # crop
        unpadding = (num_cspace - num_space) // 2
        BfiltsX = Bfilts[:, unpadding + np.arange(num_space), :, :]

    bifilts = np.concatenate((BfiltsX[:, :, :, 0], BfiltsX[:, :, :, 1]), axis=1)
    if to_plot:
        DU.plot_filters(filters=bifilts, flipxy=True)
    else:
        return bifilts
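# Hedged usage sketch (assumes a fitted binocular NDN model `binoc_mod` from the
# surrounding scripts): return the left- and right-eye filters side by side instead
# of plotting them; with num_space=36 the spatial axis has 72 columns (left eye
# followed by right eye), with one slice per filter.
# bifilts = compute_binocular_filters(binoc_mod, to_plot=False)
# bifilts.shape  # -> (num_lags, 2 * num_space, NF)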
def disparity_tuning(Einfo, r, used_inds=None, num_dlags=8, fr1or3=3, to_plot=False):
    if used_inds is None:
        used_inds = range(len(r))
    dmat = disparity_matrix(Einfo['dispt'], Einfo['corrt'])
    ND = (dmat.shape[1] - 2) // 2

    # Weight all stimuli by their frequency of occurrence
    if (fr1or3 == 3) or (fr1or3 == 1):
        frs_valid = Einfo['frs'] == fr1or3
    else:
        frs_valid = Einfo['frs'] > 0
    to_use = frs_valid[used_inds]

    # dmatN = dmat / np.mean(dmat[used_inds[to_use],:], axis=0) * np.mean(dmat[used_inds[to_use],:])
    dmatN = dmat / np.mean(dmat[used_inds[to_use], :], axis=0)  # will be stim rate
    # if every stim resulted in 1 spk, this would be 1 as is
    # nrms = np.mean(dmat[used_inds[to_use],:], axis=0)  # number of stimuli of each type

    Xmat = NDNutils.create_time_embedding(
        dmatN[:, range(ND * 2)], [num_dlags, 2 * ND, 1])[used_inds, :]
    # uncorrelated response
    Umat = NDNutils.create_time_embedding(
        dmatN[:, [-2]], [num_dlags, 1, 1])[used_inds, :]

    # if len(r) > len(used_inds):
    resp = deepcopy(r[used_inds])
    # else:
    #     resp = r

    # Nspks = np.sum(resp[to_use, :], axis=0)
    Nspks = len(to_use)
    # this will end up being the number of spikes associated with each stim at
    # different lags, divided by the number of time points used (i.e. prob of spike per bin)
    Dsta = np.reshape(Xmat[to_use, :].T @ resp[to_use], [2 * ND, num_dlags]) / Nspks
    Usta = (Umat[to_use, :].T @ resp[to_use])[:, 0] / Nspks

    # Rudimentary analysis
    best_lag = np.argmax(np.max(Dsta[range(ND), :], axis=0))
    Dtun = np.reshape(Dsta[:, best_lag], [2, ND]).T
    uncor_resp = Usta[best_lag]
    Dinfo = {
        'Dsta': Dsta, 'Dtun': Dtun, 'uncor_resp': uncor_resp,
        'best_lag': best_lag, 'uncor_sta': Usta,
        'disp_list': Einfo['disp_list'][2:]}

    if to_plot:
        DU.subplot_setup(1, 2)
        plt.subplot(1, 2, 1)
        DU.plot_norm(Dsta.T - uncor_resp, cmap='bwr')
        plt.plot([ND - 0.5, ND - 0.5], [-0.5, num_dlags - 0.5], 'k')
        plt.plot([-0.5, 2 * ND - 0.5], [best_lag, best_lag], 'k--')
        plt.subplot(1, 2, 2)
        plt.plot(Dtun)
        plt.plot(-Dtun[:, 1] + 2 * uncor_resp, 'm--')
        plt.plot([0, ND - 1], [uncor_resp, uncor_resp], 'k')
        plt.xlim([0, ND - 1])
        plt.show()

    return Dinfo
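# Hedged usage sketch (assumes the experiment dict `Einfo`, a response column `robs`,
# and a valid-index vector `used_inds` from the surrounding scripts):
# Dinfo = disparity_tuning(Einfo, robs, used_inds=used_inds, num_dlags=8, to_plot=True)
# Dinfo['Dtun']       # [ND x 2] tuning at the best lag, one column per correlation condition
# Dinfo['uncor_resp'] # response to uncorrelated stimuli at the best lag
# Dinfo['disp_list']  # disparity values corresponding to the rows of Dtun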
v2f = retV1.fit_variables(fit_biases=False)
v2f[0][0]['biases'] = True
v2f[-1][-1]['biases'] = True

#%% train
_ = retV1.train(input_data=[Xstim], output_data=Robs,
                train_indxs=Ui, test_indxs=Xi,
                silent=False, learning_alg='adam',
                opt_params=adam_params, fit_variables=v2f)

# %% fit
DU.plot_3dfilters(retV1)

# ======================================================================
# ======================================================================
# ======================================================================
# ======================================================================
# STRAY CODE BELOW HERE
# ======================================================================
# ======================================================================
# ======================================================================

#%%

#%%
def disparity_tuning(Einfo, r, used_inds=None, num_dlags=8, fr1or3=3, to_plot=False):
    if used_inds is None:
        used_inds = range(len(r))
    dmat = disparity_matrix(Einfo['dispt'], Einfo['corrt'])
    ND = (dmat.shape[1] - 2) // 2

    # Weight all stimuli by their frequency of occurrence
    if (fr1or3 == 3) or (fr1or3 == 1):
        frs_valid = Einfo['frs'] == fr1or3
    else:
        frs_valid = Einfo['frs'] > 0
    to_use = frs_valid[used_inds]

    dmatN = dmat / np.mean(dmat[used_inds[to_use], :], axis=0) * np.mean(
        dmat[used_inds[to_use], :])

    Xmat = NDNutils.create_time_embedding(
        dmatN[:, range(ND * 2)], [num_dlags, 2 * ND, 1])[used_inds, :]
    # uncorrelated response
    Umat = NDNutils.create_time_embedding(
        dmatN[:, [-2]], [num_dlags, 1, 1])[used_inds, :]

    # if len(r) > len(used_inds):
    resp = deepcopy(r[used_inds])
    # else:
    #     resp = r

    Nspks = np.sum(resp[to_use, :], axis=0)
    Dsta = np.reshape(Xmat[to_use, :].T @ resp[to_use], [2 * ND, num_dlags]) / Nspks
    Usta = (Umat[to_use, :].T @ resp[to_use])[:, 0] / Nspks

    # Rudimentary analysis
    best_lag = np.argmax(np.max(Dsta[range(ND), :], axis=0))
    Dtun = np.reshape(Dsta[:, best_lag], [2, ND]).T
    uncor_resp = Usta[best_lag]
    Dinfo = {
        'Dsta': Dsta, 'Dtun': Dtun, 'uncor_resp': uncor_resp,
        'best_lag': best_lag, 'uncor_sta': Usta}

    if to_plot:
        DU.subplot_setup(1, 2)
        plt.subplot(1, 2, 1)
        DU.plot_norm(Dsta.T - uncor_resp, cmap='bwr')
        plt.plot([ND - 0.5, ND - 0.5], [-0.5, num_dlags - 0.5], 'k')
        plt.plot([-0.5, 2 * ND - 0.5], [best_lag, best_lag], 'k--')
        plt.subplot(1, 2, 2)
        plt.plot(Dtun)
        plt.plot(-Dtun[:, 1] + 2 * uncor_resp, 'm--')
        plt.plot([0, ND - 1], [uncor_resp, uncor_resp], 'k')
        plt.xlim([0, ND - 1])
        plt.show()

    return Dinfo
              train_indxs=Ui, test_indxs=Xi,
              learning_alg=optimizer, opt_params=opt_params,
              use_dropout=False)

# %% evaluate models
Ti = opts['Ti']
LLx0 = glm.eval_models(input_data=[Xstim], output_data=Robs,
                       data_indxs=Ti, nulladjusted=null_adjusted)
print(LLx0)

# %% plot learned RFs
filters = DU.compute_spatiotemporal_filters(glm)
gt.plot_3dfilters(filters, basis=basis)

# %%
Rpred0 = glm.generate_prediction(input_data=[Xstim])

#%%
# cc += 1
cc = 26
Ti = opts['Ti']
r = np.reshape(Robs[Ti, cc], (opts['num_repeats'], -1))
r0 = np.reshape(Rpred0[Ti, cc], (opts['num_repeats'], -1))
r = np.average(r, axis=0)
r0 = np.average(r0, axis=0)
plt.plot(r)
plt.plot(r0)
                                    act_funcs=['relu', 'relu'],
                                    verbose=True,
                                    reg_list={'d2x': [XTreg],
                                              'l1': [L1reg0, L1reg0],
                                              'glocal': [Greg0]})

side_par = NDNutils.ffnetwork_params(
    network_type='side', xstim_n=None, ffnet_n=1,
    layer_sizes=[NC], layer_types=['normal'],
    normalization=[-1], act_funcs=['softplus'],
    verbose=True, reg_list={'max': [Mreg0]})
side_par['pos_constraints'] = True

side2 = NDN.NDN([t_layer, ndn_par, side_par], ffnet_out=2, noise_dist='poisson')

#%%
gab_array = DU.gabor_array(NX2 // 2, num_angles=num_subs // 2, both_phases=True)
side2.networks[1].layers[0].weights = deepcopy(gab_array)

input_data = stmp

NumBlocks = blocks.shape[0]
bad_blocks = np.where((blocks[:, 1] - blocks[:, 0]) < 10)[0]
good_blocks = np.setdiff1d(np.arange(0, NumBlocks - 1), bad_blocks)
blocks = blocks[good_blocks, :]
NumBlocks = blocks.shape[0]
Ui, Xi = NDNutils.generate_xv_folds(NumBlocks)

_ = side2.train(input_data=input_data, output_data=Rvalid,
                train_indxs=Ui, test_indxs=Xi,
                silent=False, learning_alg='adam',
                opt_params=adam_params, blocks=blocks + 1)

# adjust regularization and re-train
LLx.append(LLx0)

#%% compare two models
plt.figure()
modi = 1
modj = 6
plt.plot(LLx[modi], LLx[modj], '.')
plt.plot(plt.xlim(), plt.xlim(), 'k')
plt.xlabel(names[modi])
plt.ylabel(names[modj])

# %% plot learned RFs
i = 6
print(names[i])
filters = DU.compute_spatiotemporal_filters(ndns[i])
gt.plot_3dfilters(filters, basis=basis)

# %% plot gain and offset
modj = 6
xax = np.arange(-back_shifts, num_saclags - back_shifts, 1)
ix = LLx[modj] > 0.5

plt.figure(figsize=(5, 4))
if len(ndns[modj].networks) > 3:
    plt.subplot(1, 2, 2)
    f = plt.plot(xax, ndns[modj].networks[2].layers[0].weights[:, ix], '#3ed8e6')
    plt.subplot(1, 2, 1)
f = plt.plot(xax, ndns[modj].networks[1].layers[0].weights[:, ix], '#3ed8e6')
stas = (Xstim.T @ (Rvalid - np.mean(Rvalid, axis=0))) / np.sum(Rvalid, axis=0)
stas /= np.sum(stas, axis=0)
gqm0.networks[0].layers[0].weights[:] = deepcopy(stas[:])

gqm0.set_regularization('d2x', reg_val=bestreg, ffnet_target=0)
gqm0.set_regularization('d2x', reg_val=bestreg, ffnet_target=1)
gqm0.set_regularization('d2x', reg_val=bestreg, ffnet_target=2)

# train initial model
_ = gqm0.train(input_data=[Xstim], output_data=Rvalid,
               train_indxs=Ui, test_indxs=Xi,
               learning_alg='adam', opt_params=adam_params,
               fit_variables=v2f0)

# f = plt.plot(np.asarray(LLxs))

#%% plot model
DU.plot_3dfilters(gqm0, ffnet=0)
DU.plot_3dfilters(gqm0, ffnet=1)
DU.plot_3dfilters(gqm0, ffnet=2)

#%% plot fit
LLx = gqm0.eval_models(input_data=Xstim, output_data=Rvalid,
                       data_indxs=Xi, nulladjusted=True)
plt.figure()
plt.plot(LLx, '-o')
plt.axhline(0, color='k')

#%% Run eye correction
# Run it all once
eyeAtFrameCentered = (eyeAtFrame - (640, 380))

centers5, locs, LLspace1 = ne.get_corr_grid(gqm0, Stim, Robs, [NX, NY], cids,
v2f = nim0.fit_variables(fit_biases=True)

# train
_ = nim0.train(input_data=[Xstim], output_data=Robs,
               train_indxs=Ui, test_indxs=Xi,
               learning_alg=optimizer, opt_params=opt_params,
               use_dropout=False, fit_variables=v2f)
print("Done")

DU.plot_3dfilters(nim0)

# only include cells that are well fit
LLx = nim0.eval_models(input_data=[Xstim], output_data=Robs,
                       data_indxs=Xi, nulladjusted=True)
plt.plot(LLx, 'o')
cids = np.where(LLx > 0.05)[0]
NC = len(cids)
Robsv = deepcopy(Robs[:, cids])

num_subs = NC // 2

nim_par = NDNutils.ffnetwork_params(input_dims=[1, NX, NY, num_lags],
# sta = (sta - np.min(sta)) / (np.max(sta) - np.min(sta))
# glm0.networks[0].layers[0].weights[:,0] = deepcopy((sta - np.min(sta)) / (np.max(sta) - np.min(sta)))
v2f0 = glm0.fit_variables(fit_biases=True)

# train initial model
_ = glm0.train(input_data=[Xstim], output_data=Robs,
               train_indxs=Ui, test_indxs=Xi,
               learning_alg='lbfgs', opt_params=lbfgs_params,
               fit_variables=v2f0)

# plot filters
DU.plot_3dfilters(glm0)

# Find best regularization
glmbest = glm0.copy_model()
[LLpath, glms] = NDNutils.reg_path(glmbest, input_data=[Xstim], output_data=Robs,
                                   train_indxs=Ui, test_indxs=Xi,
                                   reg_type='glocal',
                                   reg_vals=[1e-6, 1e-4, 1e-3, 1e-2, 0.1, 1],
                                   layer_target=0, ffnet_target=0,
                                   learning_alg='lbfgs',
v2f = retV1.fit_variables(fit_biases=True)
v2f[0][0]['biases'] = True
v2f[-1][0]['biases'] = True

#%% train
_ = retV1b.train(input_data=[Xstim, Rvalid], output_data=Rvalid,
                 train_indxs=Ui, test_indxs=Xi,
                 silent=False, learning_alg='adam',
                 opt_params=adam_params, fit_variables=v2fb)

# %% fit
DU.plot_3dfilters(retV1b)

#%%
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(retV1b.networks[0].layers[0].weights)
plt.title("Temporal Kernels")
plt.subplot(1, 2, 2)
plt.plot(retV1b.networks[1].layers[1].weights)
plt.title("Latent temporal kernel")

#%% get test likelihood
LLx1 = retV1b.eval_models(input_data=[Xstim, Rvalid], output_data=Rvalid,
                          data_indxs=Xi, nulladjusted=True, use_gpu=False)