import pickle

import numpy as np


def save_redshift_samples(th_samps, ll_samps, q_idx, K, V, qso_info, chain_idx, use_mle):
    """ save basis fit info """
    # dump separately - pickle is super inefficient
    if chain_idx == "temper":
        fbase = 'cache/photo_z_samps/redshift_samples_K-%d_V-%d_qso_%d_chain_%s_mle_%r' % \
            (K, V, q_idx, chain_idx, use_mle)
    else:
        fbase = 'cache/photo_z_samps/redshift_samples_K-%d_V-%d_qso_%d_chain_%d_mle_%r' % \
            (K, V, q_idx, chain_idx, use_mle)
    np.save(fbase + '.npy', th_samps)
    with open(fbase + '.pkl', 'wb') as handle:
        pickle.dump(ll_samps, handle)
        pickle.dump(q_idx, handle)
        pickle.dump(qso_info, handle)
        pickle.dump(chain_idx, handle)
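# A minimal loader counterpart (a sketch, not part of the original source):
# pickle objects written sequentially to one file handle must be read back
# with sequential pickle.load calls in the same order they were dumped.
def load_redshift_samples(fbase):
    th_samps = np.load(fbase + '.npy')
    with open(fbase + '.pkl', 'rb') as handle:
        ll_samps = pickle.load(handle)
        q_idx = pickle.load(handle)
        qso_info = pickle.load(handle)
        chain_idx = pickle.load(handle)
    return th_samps, ll_samps, q_idx, qso_info, chain_idx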
def save_basis_samples(th_samps, ll_samps, lam0, lam0_delta, parser, chain_idx):
    """ save basis fit info """
    # grab B value for shape info
    B = parser.get(th_samps[0, :], 'betas')
    # dump separately - pickle is super inefficient
    fbase = 'cache/basis_samples_K-%d_V-%d_chain_%d' % (B.shape[0], B.shape[1], chain_idx)
    np.save(fbase + '.npy', th_samps)
    with open(fbase + '.pkl', 'wb') as handle:
        pickle.dump(ll_samps, handle)
        pickle.dump(lam0, handle)
        pickle.dump(lam0_delta, handle)
        pickle.dump(parser, handle)
        pickle.dump(chain_idx, handle)
import autograd.numpy as np
from autograd import grad

# relies on stack_params/unstack_params defined elsewhere in this module (util)


def adam(objFunc, u, V, R):
    ''' Adam Stochastic Gradient Descent '''
    m1 = 0
    m2 = 0
    beta1 = 0.9
    beta2 = 0.999
    epsilon = 1e-8
    t = 0
    learning_rate = 0.001
    dfunc = grad(objFunc)
    old_norm = 100
    grady = []
    for i in range(1000):
        t += 1
        params = stack_params(u, V, R)
        gradients = dfunc(params)
        norm = np.linalg.norm(gradients)
        grady.append(norm)
        # biased first and second moment estimates
        m1 = beta1 * m1 + (1 - beta1) * gradients
        m2 = beta2 * m2 + (1 - beta2) * gradients**2
        # bias-corrected moment estimates
        m1_hat = m1 / (1 - beta1**t)
        m2_hat = m2 / (1 - beta2**t)
        delta = learning_rate * m1_hat / (np.sqrt(m2_hat) + epsilon)
        du, dV, dR = unstack_params(delta)
        u = u - du
        V = V - dV
        #R = R - dR
        # stop once the gradient norm has stabilised
        if np.abs(old_norm - norm) < 0.0001:
            break
        old_norm = norm
    grady = np.array(grady)
    np.save('adam_gradients', grady)
    return u, V, R
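# A minimal usage sketch (assumptions: this sits in the same module as
# stack_params/unstack_params, and the objective is autograd-differentiable
# over the stacked parameter vector; the quadratic below is a toy stand-in,
# not the model's real objective).
def toy_objective(params):
    return np.sum(params**2)

u0 = 0.5 * np.ones(2)
V0 = np.eye(2)
R0 = np.eye(2)
u_fit, V_fit, R_fit = adam(toy_objective, u0, V0, R0)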
def fit(self, n_iter):
    old_params = util.stack_params(self.u, self.V, self.R)
    x_test = util.generate_data(5, self.W, self.sigx, self.dimx, self.dimz)[0]
    llh = []
    print "True: ", self.true_marg(x_test)

    def objective_wrapper(params, x):
        u, V, R = util.unstack_params(params)
        return self.objective(x, R, u, V)

    for j in xrange(n_iter):
        for i in xrange(self.n):
            x = self.observed[i]
            obj_local = lambda param: objective_wrapper(param, x)
            #u, V, R = self.gradient_descent(obj_local, self.u, self.V, self.R, x)
            u, V, R = util.adam(obj_local, self.u, self.V, self.R)
            #u, V, R = util.optimizer(obj_local, self.u, self.V, self.R)
            #print self.objective(x, R, self.u, self.V)

            # temp: damped SEP update (a is alpha)
            a = 1 / float(self.n)
            self.u = np.dot(self.V, self.u) * (1 - a) + a * (np.dot(V, u))
            self.V = self.V + a * (V - self.V)
            self.u = np.dot(np.linalg.inv(self.V), self.u)

            # RM update
            a2 = 1 / float(self.n)
            self.R = (1 - a2) * self.R + a2 * R
            #print "obj: ", self.objective(x, R, self.u, self.V)

        params = util.stack_params(self.u, self.V, self.R)
        diff = np.linalg.norm(old_params - params)
        old_params = params
        ll = self.get_marginal(self.u, self.V, self.R, x_test)
        llh.append(ll)
        print "m: ", ll, diff, "iter: ", j
    np.save('DATA50testNr2', np.array(llh))
def fit(self, n_iter):
    old_params = util.stack_params(self.u, self.V, self.R)
    x_test = util.generate_data(20, self.W, self.sigx, self.dimx, self.dimz)[0]
    llh = []
    for j in xrange(n_iter):
        for i in xrange(self.n):
            x = self.observed[i]
            obj_local = lambda param: self.objective_wrapper(param, x)
            u, V, R = util.gradient_descent(obj_local, self.u, self.V, self.R)
            #u, V, R = util.optimizer(obj_local, self.u, self.V, self.R)

            # do SEP updates -- damping; a is alpha -- see the SEP paper
            a = 1 / float(self.n)
            self.u = np.dot(self.V, self.u) * (1 - a) + a * (np.dot(V, u))
            self.V = self.V + a * (V - self.V)
            self.u = np.dot(np.linalg.inv(self.V), self.u)

            # RM update
            a2 = 1 / float(self.n)
            self.R = (1 - a2) * self.R + a2 * R

        params = util.stack_params(self.u, self.V, self.R)
        diff = np.linalg.norm(old_params - params)
        old_params = params
        ll = self.get_marginal(self.u, self.V, self.R, x_test)
        llh.append(ll)
        print "m: ", ll, diff
    np.save('learn', np.array(llh))
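# A sketch (not from the original source) of the damped update above in
# isolation: treating V as a precision and u as a mean, the update averages
# the natural parameters (V, V.u) with weight a = 1/n and maps back to u.
import numpy as np

def damped_sep_update(u_global, V_global, u_local, V_local, a):
    # damped average of the precision-mean products and of the precisions
    eta = np.dot(V_global, u_global) * (1 - a) + a * np.dot(V_local, u_local)
    V_new = V_global + a * (V_local - V_global)
    # recover the mean; solve() is equivalent to inv(V_new).dot(eta) but stabler
    u_new = np.linalg.solve(V_new, eta)
    return u_new, V_new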
plt.figure()
plt.plot(mean_list_i[0, :], 'r-', label='Chain 1')
plt.plot(mean_list_i[1, :], 'g-', label='Chain 2')
plt.plot(mean_list_i[2, :], 'm-', label='Chain 3')
plt.plot(mean_list_i[3, :], 'b-', label='Chain 4')
plt.plot(np.repeat(posterior_mean[0], len(mean_list_i[1, :])), 'k-', label='True posterior')
plt.ylim((3, 6))
plt.xlabel('Iterations')
plt.ylabel('Means')
plt.legend()
plt.savefig('plots/vi_chains_mean_MF.pdf')
np.save('mean_0_chains_MF', mean_list_i)
print(mean_list_i)

plt.figure()
plt.plot(sigmas_list_i[0, :], 'r-', label='Chain 1')
plt.plot(sigmas_list_i[1, :], 'g-', label='Chain 2')
plt.plot(sigmas_list_i[2, :], 'm-', label='Chain 3')
plt.plot(sigmas_list_i[3, :], 'b-', label='Chain 4')
plt.xlabel('Iterations')
plt.ylabel('Standard deviations')  # sigmas; the reference line is the sqrt of the posterior variance
plt.plot(np.repeat(np.sqrt(posterior_variance[0, 0]), len(sigmas_list_i[1, :])), 'k-', label='True posterior')
plt.ylim((-0.2, 1.2))
        # Calculate the parameters of the ADAM optimizer
        m = beta_1 * m + (1 - beta_1) * grad
        v = beta_2 * v + (1 - beta_2) * (grad * grad)
        m_hat = m / (1 - (beta_1**t))
        v_hat = v / (1 - (beta_2**t))
        # Update the parameters using ADAM optimizer (gradient ascent on the ELBO,
        # hence the plus sign)
        flattened_current_params = flattened_current_params + (alpha * m_hat) / (np.sqrt(v_hat) + epsilon)
        elbo_est += objective(flattened_current_params)
        t += 1
    print("Epoch: %d ELBO: %e" % (epoch, elbo_est / np.ceil(N / batch_size)))

# We save the trained params so we don't have to retrain each time
np.save(os.path.join("trained_params", "data.npy"), flattened_current_params)

# We obtain the final trained parameters
flattened_current_params = np.load(os.path.join("trained_params", "data.npy"))
gen_params, rec_params = unflat_params(flattened_current_params)

####---- Task 3.1 ----####
# We generate 25 samples from the prior.
# Note the prior P(z) is a standard Gaussian
num_prior_images = 25
z = npr.randn(num_prior_images, latent_dim)

# Generate the images using the prior and gen params
generated_images = neural_net_predict(gen_params, z)
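# A minimal viewing sketch (an assumption, not part of the original source:
# it takes the generated images to be flattened 28x28 grayscale, MNIST-style,
# and the output filename is made up).
import matplotlib.pyplot as plt

fig, axes = plt.subplots(5, 5, figsize=(6, 6))
for ax, img in zip(axes.ravel(), generated_images):
    ax.imshow(img.reshape(28, 28), cmap='gray')
    ax.axis('off')
fig.savefig(os.path.join("trained_params", "prior_samples.png"))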
def fit_weights_and_save(weights_file, ca_data_file='rs_sc_fg_ret_pval_0_05_210423.npy',
                         opto_silencing_data_file='vip_halo_data_for_sim.npy',
                         opto_activation_data_file='vip_chrimson_data_for_sim.npy',
                         constrain_wts=None, allow_var=True, multiout=True, multiout2=False,
                         fit_s02=True, constrain_isn=True, tv=False, l2_penalty=0.01,
                         l1_penalty=1.0, init_noise=0.1, init_W_from_lsq=False, scale_init_by=1,
                         init_W_from_file=False, init_file=None, foldT=False, free_amplitude=False,
                         correct_Eta=False, init_Eta_with_s02=False, no_halo_res=False,
                         ignore_halo_vip=False, use_opto_transforms=False, norm_opto_transforms=False,
                         nondim=False, fit_running=False, fit_non_running=True, fit_sc=True,
                         fit_fg=False, fit_ret=False, run_modulation=True, axon=True, nT=2):
    nsize,ncontrast = 6,6
    nrun = 2
    nsize,ncontrast,ndir = 6,6,8
    nstim_fg = 5
    nstim_ret = 5
    print((fit_sc,fit_fg,fit_ret))
    fit_both_running = (fit_non_running and fit_running)
    fit_all_stims = (fit_sc and fit_fg and fit_ret)
    if not fit_both_running:
        nrun = 1
        if fit_non_running:
            irun = 0
        elif fit_running:
            irun = 1
    nsc = nrun*nsize*ncontrast*ndir
    nfg = nrun*nstim_fg*ndir
    nret = nrun*nstim_ret
    npfile = np.load(ca_data_file,allow_pickle=True)[()] #,{'rs':rs},allow_pickle=True) # ,'rs_denoise':rs_denoise
    if fit_both_running:
        Rs_mean = npfile['Rs_mean_run']
        Rs_cov = npfile['Rs_cov_run']
    else:
        Rs_mean = npfile['Rs_mean'][irun]
        Rs_cov = npfile['Rs_cov'][irun]
    if not fit_all_stims:
        slcs = []
        if fit_sc:
            slcs = slcs + [slice(None,nsc)]
        if fit_fg:
            slcs = slcs + [slice(nsc,nsc+nfg)]
        if fit_ret:
            slcs = slcs + [slice(nsc+nfg,None)]
        Rs_mean,Rs_cov = get_Rs_slices(Rs_mean,Rs_cov,slcs)
    if nT==3:
        ori_dirs = [[0,4],[1,3,5,7],[2,6]]
    elif nT==2:
        ori_dirs = [[0,4],[2,6]] #[[0,4],[1,3,5,7],[2,6]]
    else:
        nT = 1
        ori_dirs = [[0,1,2,3,4,5,6,7]]
    ndims = 5
    nT = len(ori_dirs)
    nS = len(Rs_mean[0])
    fg_start = nsc*fit_sc
    ret_start = fg_start + nfg*fit_fg

    def sum_to_1(r):
        R = r.reshape((r.shape[0],-1))
        R = R/np.nansum(R[:,~np.isnan(R.sum(0))],axis=1)[:,np.newaxis] # changed 21/4/10
        return R

    def norm_to_mean(r):
        R = r.reshape((r.shape[0],-1))
        R = R/np.nanmean(R[:,~np.isnan(R.sum(0))],axis=1)[:,np.newaxis]
        return R

    def ori_avg(Rs,these_ori_dirs):
        if fit_sc:
            rs_sc = np.nanmean(Rs[:nsc].reshape((nrun,nsize,ncontrast,ndir))[:,:,:,these_ori_dirs],-1)
            rs_sc[:,1:,1:] = ssi.convolve(rs_sc,kernel,'valid')
            rs_sc = rs_sc.reshape((nrun*nsize*ncontrast))
        else:
            rs_sc = np.zeros((0,))
        if fit_fg:
            rs_fg = np.nanmean(Rs[fg_start:fg_start+nfg].reshape((nrun,nstim_fg,ndir))[:,:,these_ori_dirs],-1)
            rs_fg = rs_fg.reshape((nrun*nstim_fg))
        else:
            rs_fg = np.zeros((0,))
        if fit_ret:
            rs_ret = Rs[ret_start:ret_start+nret]
        else:
            rs_ret = np.zeros((0,))
        Rso = np.concatenate((rs_sc,rs_fg,rs_ret))
        return Rso

    Rso_mean = [[[None for iT in range(nT)] for iS in range(nS)] for icelltype in range(len(Rs_mean))]
    Rso_cov = [[[[[None,None] for idim in range(ndims)] for iT in range(nT)] for iS in range(nS)] for icelltype in range(len(Rs_mean))]

    kernel = np.ones((1,2,2))
    kernel = kernel/kernel.sum()
    for iR,r in enumerate(Rs_mean):
        for ialign in range(nS):
            for iori in range(nT):
                Rso_mean[iR][ialign][iori] = ori_avg(Rs_mean[iR][ialign],ori_dirs[iori])
                for idim in range(ndims):
                    Rso_cov[iR][ialign][iori][idim][0] = Rs_cov[iR][ialign][idim][0]
                    Rso_cov[iR][ialign][iori][idim][1] = ori_avg(Rs_cov[iR][ialign][idim][1],ori_dirs[iori])

    def set_bound(bd,code,val=0):
        # set bounds to 0 where 0s occur in 'code'
        for iitem in range(len(bd)):
            bd[iitem][code[iitem]] = val

    nN = (36*fit_sc + 5*fit_fg + 5*fit_ret)*(1 + fit_both_running)
    nS = 2
    nP = 2 + fit_both_running
    #nT = 2
    nQ = 4
    ndims = 5
    ncelltypes = 5
    #print('foldT: %d'%foldT)
    if foldT:
        Yhat = [None for iS in range(nS)]
        Xhat = [None for iS in range(nS)]
        Ypc_list = [None for iS in range(nS)]
        Xpc_list = [None for iS in range(nS)]
        print('have not written this yet')
        assert(True==False)
    else:
        Yhat = [[None for iT in range(nT)] for iS in range(nS)]
        Xhat = [[None for iT in range(nT)] for iS in range(nS)]
        Ypc_list = [[None for iT in range(nT)] for iS in range(nS)]
        Xpc_list = [[None for iT in range(nT)] for iS in range(nS)]
    for iS in range(nS):
        mx = np.zeros((ncelltypes,))
        yy = [None for icelltype in range(ncelltypes)]
        for icelltype in range(ncelltypes):
            yy[icelltype] = np.concatenate(Rso_mean[icelltype][iS])
            mx[icelltype] = np.nanmax(yy[icelltype])
        for iT in range(nT):
            y = [Rso_mean[icelltype][iS][iT][:,np.newaxis]/mx[icelltype] for icelltype in range(1,ncelltypes)]
            Yhat[iS][iT] = np.concatenate(y,axis=1)
            Ypc_list[iS][iT] = [None for icelltype in range(1,ncelltypes)]
            for icelltype in range(1,ncelltypes):
                Ypc_list[iS][iT][icelltype-1] = [(this_dim[0]/mx[icelltype],this_dim[1]) for this_dim in Rso_cov[icelltype][iS][iT]]
            icelltype = 0
            x = Rso_mean[icelltype][iS][iT][:,np.newaxis]/mx[icelltype]
            if fit_both_running:
                run_vector = np.zeros_like(x)
                if fit_sc:
                    run_vector[int(nsc/ndir/2):int(nsc/ndir)] = 1
                    #print(run_vector.sum())
                if fit_fg:
                    #print(int((fg_start+nfg/2)/ndir))
                    #print(int((fg_start+nfg)/ndir))
                    run_vector[int((fg_start+nfg/2)/ndir):int((fg_start+nfg)/ndir)] = 1
                    #print(run_vector.sum())
                if fit_ret:
                    run_vector[int(ret_start/ndir+nret/2):int(ret_start/ndir+nret)] = 1
                print('run vector mean: '+str(run_vector.mean()))
                print('run vector shape: '+str(run_vector.shape))
            else:
                run_vector = np.zeros((x.shape[0],0))
            Xhat[iS][iT] = np.concatenate((x,np.ones_like(x),run_vector),axis=1)
            Xpc_list[iS][iT] = [None for iinput in range(2+fit_both_running)]
            Xpc_list[iS][iT][0] = [(this_dim[0]/mx[icelltype],this_dim[1]) for this_dim in Rso_cov[icelltype][iS][iT]]
            Xpc_list[iS][iT][1] = [(0,np.zeros((Xhat[0][0].shape[0],))) for idim in range(ndims)]
            if fit_both_running:
                Xpc_list[iS][iT][2] = [(0,np.zeros((Xhat[0][0].shape[0],))) for idim in range(ndims)]
    nN,nP = Xhat[0][0].shape
    nQ = Yhat[0][0].shape[1]

    # code for bounds: 0   , constrained to 0
    #                  +/-1, constrained to +/-1
    #                  1.5 , constrained to [0,1]
    #                  -1.5, constrained to [-1,1]
    #                  2   , constrained to [0,inf)
    #                  -2  , constrained to (-inf,0]
    #                  3   , unconstrained
    W0x_bounds = 3*np.ones((nP,nQ),dtype=int)
    W0x_bounds[0,:] = 2 # L4 PCs are excitatory
    W0x_bounds[0,1] = 0 # SSTs don't receive L4 input
    if allow_var:
        if nondim:
            W1x_bounds = -1.5*np.ones(W0x_bounds.shape) #W0x_bounds.copy()*0 #np.zeros_like(W0x_bounds)
        else:
            W1x_bounds = 3*np.ones(W0x_bounds.shape) #W0x_bounds.copy()*0 #np.zeros_like(W0x_bounds)
        W1x_bounds[0,1] = 0
    else:
        W1x_bounds = np.zeros(W0x_bounds.shape) #W0x_bounds.copy()*0 #np.zeros_like(W0x_bounds)
    W0y_bounds = 3*np.ones((nQ,nQ),dtype=int)
    W0y_bounds[0,:] = 2 # PCs are excitatory
    W0y_bounds[1:,:] = -2 # all the cell types except PCs are inhibitory
    W0y_bounds[1,1] = 0 # SSTs don't inhibit themselves
    # W0y_bounds[3,1] = 0 # PVs are allowed to inhibit SSTs, consistent with Hillel's unpublished results, but not consistent with Pfeffer et al.
    W0y_bounds[2,0] = 0 # VIPs don't inhibit L2/3 PCs. According to Pfeffer et al., only L5 PCs were found to get VIP inhibition
    W0y_bounds[2,2] = 0 # newly added: no VIP-VIP inhibition

    def tile_nS_nT_nN(kernel):
        row = np.concatenate([kernel for idim in range(nS*nT)],axis=0)[np.newaxis,:]
        tiled = np.concatenate([row for irow in range(nN)],axis=0)
        return tiled

    if fit_s02:
        s02_bounds = 2*np.ones((nQ,)) # permitting noise as a free parameter
    else:
        s02_bounds = np.ones((nQ,))
    Kin0_bounds = 1.5*np.ones((nQ*(nS>1),))
    kappa_bounds = np.ones((nS-1,))
    # kappa_bounds = 2*np.ones((1,))
    Tin0_bounds = 1.5*np.ones((nQ*(nT>1),))
    #T_bounds[2:4] = 1 # PV and VIP are constrained to have flat ori tuning
    #Tin0_bounds[1:4] = 1 # SST, VIP, and PV are constrained to have flat ori tuning
    if nondim:
        kt_factor = -1.5
    else:
        kt_factor = 3
    if allow_var:
        W1y_bounds = kt_factor*np.ones(W0y_bounds.shape) #W0y_bounds.copy()*0 #np.zeros_like(W0y_bounds)
        Kin1_bounds = kt_factor*np.ones(Kin0_bounds.shape) #W0y_bounds.copy()*0 #np.zeros_like(W0y_bounds)
        Tin1_bounds = kt_factor*np.ones(Tin0_bounds.shape) #W0y_bounds.copy()*0 #np.zeros_like(W0y_bounds)
        W1y_bounds[1,1] = 0
        #W1y_bounds[3,1] = 0
        W1y_bounds[2,0] = 0
        W1y_bounds[2,2] = 0 # newly added: no VIP-VIP inhibition
    else:
        W1y_bounds = np.zeros(W0y_bounds.shape) #W0y_bounds.copy()*0 #np.zeros_like(W0y_bounds)
        Kin1_bounds = 0*np.ones(Kin0_bounds.shape) #W0y_bounds.copy()*0 #np.zeros_like(W0y_bounds)
        Tin1_bounds = 0*np.ones(Tin0_bounds.shape) #W0y_bounds.copy()*0 #np.zeros_like(W0y_bounds)

    # applied here, after W1y_bounds is defined, since it zeros entries of both
    if not constrain_wts is None:
        for wt in constrain_wts:
            W0y_bounds[wt[0],wt[1]] = 0
            W1y_bounds[wt[0],wt[1]] = 0

    if multiout:
        W2x_bounds = W1x_bounds.copy()
        W2y_bounds = W1y_bounds.copy()
        if multiout2:
            W3x_bounds = W1x_bounds.copy()
            W3y_bounds = W1y_bounds.copy()
        else:
            W3x_bounds = W1x_bounds.copy()*0
            W3y_bounds = W1y_bounds.copy()*0
    else:
        W2x_bounds = W1x_bounds.copy()*0
        W2y_bounds = W1y_bounds.copy()*0
        W3x_bounds = W1x_bounds.copy()*0
        W3y_bounds = W1y_bounds.copy()*0
    if run_modulation:
        W0xrun_bounds = -1.5*(W0x_bounds!=0).astype('int')
        W0yrun_bounds = -1.5*(W0y_bounds!=0).astype('int')
    else:
        W0xrun_bounds = W0x_bounds.copy()*0
        W0yrun_bounds = W0y_bounds.copy()*0
    if nS>1:
        Kxout0_bounds = np.array((1.5,)+tuple(np.zeros((nP-1,))))
        Kxout1_bounds = np.array((kt_factor,)+tuple(np.zeros((nP-1,))))
    else:
        Kxout0_bounds = np.zeros((0,))
        Kxout1_bounds = np.zeros((0,))
    if nT>1:
        Txout0_bounds = Kxout0_bounds.copy()
        Txout1_bounds = Kxout1_bounds.copy()
    else:
        Txout0_bounds = np.zeros((0,))
        Txout1_bounds = np.zeros((0,))
    Kyout0_bounds = Kin0_bounds.copy()
    Tyout0_bounds = Tin0_bounds.copy()
    Kyout1_bounds = Kin1_bounds.copy()
    Tyout1_bounds = Tin1_bounds.copy()
    if not axon:
        Kxout0_bounds = np.ones_like(Kxout0_bounds)
        Txout0_bounds = np.ones_like(Txout0_bounds)
        Kxout1_bounds = np.zeros_like(Kxout1_bounds)
        Txout1_bounds = np.zeros_like(Txout1_bounds)
        Kyout0_bounds = np.ones_like(Kyout0_bounds)
        Tyout0_bounds = np.ones_like(Tyout0_bounds)
        Kyout1_bounds = np.zeros_like(Kyout1_bounds)
        Tyout1_bounds = np.zeros_like(Tyout1_bounds)
    if fit_both_running:
        to_tile = Xhat[0][0][:,1:]
        to_tile = np.concatenate((2*np.ones((to_tile.shape[0],1)),to_tile),axis=1)
        X_bounds = np.tile(to_tile,(1,nS*nT))
    else:
        X_bounds = tile_nS_nT_nN(np.array([2,1]))
    #print(X_bounds.shape)
    # X_bounds = np.array([np.array([2,1,2,1])]*nN)
    if fit_both_running:
        Xp_bounds = tile_nS_nT_nN(np.array([3,0,0])) # edited to set XXp to 0 for spont. term
    else:
        Xp_bounds = tile_nS_nT_nN(np.array([3,0])) # edited to set XXp to 0 for spont. term
    if not multiout:
        Xp_bounds = Xp_bounds*0
    # Xp_bounds = np.array([np.array([3,1,3,1])]*nN)
    # Y_bounds = tile_nS_nT_nN(2*np.ones((nQ,)))
    # # Y_bounds = 2*np.ones((nN,nT*nS*nQ))
    Eta_bounds = tile_nS_nT_nN(3*np.ones((nQ,)))
    # Eta_bounds = 3*np.ones((nN,nT*nS*nQ))
    if allow_var:
        Xi_bounds = tile_nS_nT_nN(3*np.ones((nQ,)))
    else:
        Xi_bounds = tile_nS_nT_nN(np.zeros((nQ,)))
    #Xi_bounds = tile_nS_nT_nN(3*np.ones((nQ,))) # temporarily allowing Xi even if W1 is not allowed
    # Xi_bounds = 3*np.ones((nN,nT*nS*nQ))
    h1_bounds = -2*np.ones((1,))
    h2_bounds = 2*np.ones((1,))
    bl_bounds = 3*np.ones((nQ,))
    if free_amplitude:
        amp_bounds = 2*np.ones((nT*nS*nQ,))
    else:
        amp_bounds = 1*np.ones((nT*nS*nQ,))
    # shapes = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ,),(1,),(nN,nS*nP),(nN,nS*nQ),(nN,nS*nQ),(nN,nS*nQ)]
    #shapes = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ*(nS-1),),(nQ*(nS-1),),(nQ*(nS-1),),(nQ*(nS-1),),(1,),(nQ*(nT-1),),(nQ*(nT-1),),(nQ*(nT-1),),(nQ*(nT-1),),(nN,nT*nS*nP),(nN,nT*nS*nP),(nN,nT*nS*nQ),(nN,nT*nS*nQ),(1,)]
    shapes1 = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ*(nS>1),),(nQ*(nS>1),),(nP*(nS>1),),(nQ*(nS>1),),(nP*(nS>1),),(nQ*(nS>1),),(1,),(nQ*(nT>1),),(nQ*(nT>1),),(nP*(nT>1),),(nQ*(nT>1),),(nP*(nT>1),),(nQ*(nT>1),),(1,),(1,),(nQ,),(nT*nS*nQ,)]
    print('shapes1: '+str(shapes1))
    shapes2 = [(nN,nT*nS*nP),(nN,nT*nS*nP),(nN,nT*nS*nQ),(nN,nT*nS*nQ)]
    # W0x, W0y, W1x, W1y, W2x, W2y, W3x, W3y, s02, k, kappa, T, XX, XXp, Eta, Xi
    #lb = [-np.inf*np.ones(shp) for shp in shapes]
    #ub = [np.inf*np.ones(shp) for shp in shapes]
    #bdlist = [W0x_bounds,W0y_bounds,W1x_bounds,W1y_bounds,W2x_bounds,W2y_bounds,W3x_bounds,W3y_bounds,s02_bounds,k0_bounds,k1_bounds,k2_bounds,k3_bounds,kappa_bounds,Tin0_bounds,Tin1_bounds,Tout0_bounds,Tout1_bounds,X_bounds,Xp_bounds,Eta_bounds,Xi_bounds,h_bounds]
    bd1list = [W0x_bounds,W0y_bounds,W1x_bounds,W1y_bounds,W2x_bounds,W2y_bounds,W3x_bounds,W3y_bounds,W0xrun_bounds,W0yrun_bounds,s02_bounds,Kin0_bounds,Kin1_bounds,Kxout0_bounds,Kyout0_bounds,Kxout1_bounds,Kyout1_bounds,kappa_bounds,Tin0_bounds,Tin1_bounds,Txout0_bounds,Tyout0_bounds,Txout1_bounds,Tyout1_bounds,h1_bounds,h2_bounds,bl_bounds,amp_bounds]
    bd2list = [X_bounds,Xp_bounds,Eta_bounds,Xi_bounds]
    print('bd1 shapes: '+str([b.shape for b in bd1list]))
    lb1,ub1 = [[sgn*np.inf*np.ones(shp) for shp in shapes1] for sgn in [-1,1]]
    lb1,ub1 = calnet.utils.set_bounds_by_code(lb1,ub1,bd1list)
    lb2,ub2 = [[sgn*np.inf*np.ones(shp) for shp in shapes2] for sgn in [-1,1]]
    lb2,ub2 = calnet.utils.set_bounds_by_code(lb2,ub2,bd2list)
    lb1 = np.concatenate([a.flatten() for a in lb1])
    ub1 = np.concatenate([b.flatten() for b in ub1])
    lb2 = np.concatenate([a.flatten() for a in lb2])
    ub2 = np.concatenate([b.flatten() for b in ub2])
    bounds1 = [(a,b) for a,b in zip(lb1,ub1)]
    bounds2 = [(a,b) for a,b in zip(lb2,ub2)]

    def compute_f_(Eta,Xi,s02):
        return sim_utils.f_miller_troyer(Eta,Xi**2+np.concatenate([s02 for ipixel in range(nS*nT)]))

    def compute_fprime_m_(Eta,Xi,s02):
        return sim_utils.fprime_miller_troyer(Eta,Xi**2+np.concatenate([s02 for ipixel in range(nS*nT)]))*Xi

    def compute_fprime_s_(Eta,Xi,s02):
        s2 = Xi**2+np.concatenate((s02,s02),axis=0)
        return sim_utils.fprime_s_miller_troyer(Eta,s2)*(Xi/s2)

    def sorted_r_eigs(w):
        drW,prW = np.linalg.eig(w)
        srtinds = np.argsort(drW)
        return drW[srtinds],prW[:,srtinds]

    #0.W0x,1.W0y,2.W1x,3.W1y,4.W2x,5.W2y,6.W3x,7.W3y,8.s02,9.Kin0,10.Kin1,11.Kout0,12.Kout1,13.kappa,14.Tin0,15.Tin1,16.Txout0,Tyout0,17.Txout1,Tyout1,18.h1,19.h2,20.bl,21.amp
    #0.XX,1.XXp,2.Eta,3.Xi
    #shapes = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ*(nS-1),),(nQ*(nS-1),),(nQ*(nS-1),),(nQ*(nS-1),),(1,),(nQ*(nT-1),),(nQ*(nT-1),),(nQ*(nT-1),),(nQ*(nT-1),),(nN,nT*nS*nP),(nN,nT*nS*nP),(nN,nT*nS*nQ),(nN,nT*nS*nQ),(1,)]
    #shapes1 = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ*(nS-1),),(nQ*(nS-1),),(nQ*(nS-1),),(nQ*(nS-1),),(1,),(nQ*(nT-1),),(nQ*(nT-1),),(nQ*(nT-1),),(nQ*(nT-1),),(1,),(1,),(nQ,),(nT*nS*nQ,)]
    #shapes2 = [(nN,nT*nS*nP),(nN,nT*nS*nP),(nN,nT*nS*nQ),(nN,nT*nS*nQ)]
    import sim_utils
    YYhat = calnet.utils.flatten_nested_list_of_2d_arrays(Yhat)

    opto_dict = np.load(opto_silencing_data_file,allow_pickle=True)[()]
    Yhat_opto = opto_dict['Yhat_opto']
    Yhat_opto = np.ones((nN*2,nQ*nS*nT))
    #Yhat_opto = Yhat_opto.reshape((nN*2,-1))
    Yhat_opto[0::12] = np.nanmean(Yhat_opto[0::12],axis=0)[np.newaxis]
    Yhat_opto[1::12] = np.nanmean(Yhat_opto[1::12],axis=0)[np.newaxis]
    Yhat_opto = Yhat_opto/np.nanmax(Yhat_opto[0::2],0)[np.newaxis,:]
    #print(Yhat_opto.shape)
    h_opto = np.zeros((nN*2,))
    #h_opto = opto_dict['h_opto']
    #dYY1 = Yhat_opto[1::2]-Yhat_opto[0::2]
    YYhat_halo = Yhat_opto.reshape((nN,2,-1))
    opto_transform1 = calnet.utils.fit_opto_transform(YYhat_halo,norm01=norm_opto_transforms)
    if no_halo_res:
        opto_transform1.res[:,[0,2,3,4,6,7]] = 0
    dYY1 = opto_transform1.transform(YYhat) - opto_transform1.preprocess(YYhat)
    #print('delta bias: %f'%dXX1[:,1].mean())
    #YYhat_halo_sim = calnet.utils.simulate_opto_effect(YYhat,YYhat_halo)
    #dYY1 = YYhat_halo_sim[:,1,:] - YYhat_halo_sim[:,0,:]

    def overwrite_plus_n(arr,to_overwrite,n):
        arr[:,to_overwrite] = arr[:,int(to_overwrite+n)]
        return arr

    for to_overwrite in [1,2]:
        n = 4
        dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res \
                = [overwrite_plus_n(x,to_overwrite,n) for x in \
                   [dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res]]
    for to_overwrite in [7]:
        n = -4
        dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res \
                = [overwrite_plus_n(x,to_overwrite,n) for x in \
                   [dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res]]

    opto_dict = np.load(opto_activation_data_file,allow_pickle=True)[()]
    Yhat_opto = opto_dict['Yhat_opto']
    Yhat_opto = np.ones((nN*2,nQ*nS*nT))
    #Yhat_opto = Yhat_opto.reshape((nN*2,-1))
    Yhat_opto[0::12] = np.nanmean(Yhat_opto[0::12],axis=0)[np.newaxis]
    Yhat_opto[1::12] = np.nanmean(Yhat_opto[1::12],axis=0)[np.newaxis]
    Yhat_opto = Yhat_opto/Yhat_opto[0::2].max(0)[np.newaxis,:]
    #print(Yhat_opto.shape)
    h_opto = np.zeros((nN*2,))
    #h_opto = opto_dict['h_opto']
    #dYY2 = Yhat_opto[1::2]-Yhat_opto[0::2]
    YYhat_chrimson = Yhat_opto.reshape((nN,2,-1))
    opto_transform2 = calnet.utils.fit_opto_transform(YYhat_chrimson,norm01=norm_opto_transforms)
    dYY2 = opto_transform2.transform(YYhat) - opto_transform2.preprocess(YYhat)

    dYY = np.concatenate((dYY1,dYY2),axis=0)
    if ignore_halo_vip:
        # note: dYY has already been built above, so this only affects dYY1 itself
        dYY1[:,2::nQ] = np.nan

    #from importlib import reload
    #reload(calnet)
    #reload(calnet.fitting_2step_spatial_feature_opto_multiout_axon_nonlinear)
    #reload(sim_utils)
    wt_dict = {}
    wt_dict['X'] = 3
    wt_dict['Y'] = 3
    wt_dict['Eta'] = 3 # 10
    wt_dict['Xi'] = 3
    wt_dict['stims'] = np.ones((nN,1)) #(np.arange(30)/30)[:,np.newaxis]**1
    # wt_dict['barrier'] = 0. #30.0 #0.1
    wt_dict['opto'] = 0 #1e0 #1e-1 #1e1
    wt_dict['smi'] = 0
    wt_dict['isn'] = 0.1
    wt_dict['tv'] = 0.1

    YYhat = calnet.utils.flatten_nested_list_of_2d_arrays(Yhat)
    XXhat = calnet.utils.flatten_nested_list_of_2d_arrays(Xhat)
    Eta0 = invert_f_mt(YYhat)
    Xi0 = invert_fprime_mt(Ypc_list,Eta0,nN=nN,nQ=nQ,nS=nS,nT=nT,foldT=foldT)

    ntries = 1
    nhyper = 1
    dt = 1e-1
    niter = int(np.round(10/dt)) #int(1e4)
    perturbation_size = 5e-2
    W1t = [[None for itry in range(ntries)] for ihyper in range(nhyper)]
    W2t = [[None for itry in range(ntries)] for ihyper in range(nhyper)]
    loss = np.zeros((nhyper,ntries))
    is_neg = np.array([b[1] for b in bounds1])==0
    counter = 0
    negatize = [np.zeros(shp,dtype='bool') for shp in shapes1]
    for ishp,shp in enumerate(shapes1):
        nel = np.prod(shp)
        negatize[ishp][:][is_neg[counter:counter+nel].reshape(shp)] = True
        counter = counter + nel
    for ihyper in range(nhyper):
        for itry in range(ntries):
            print((ihyper,itry))
            W10list = [init_noise*(ihyper+1)*np.random.rand(*shp) for shp in shapes1]
            W20list = [init_noise*(ihyper+1)*np.random.rand(*shp) for shp in shapes2]
            counter = 0
            for ishp,shp in enumerate(shapes1):
                W10list[ishp][negatize[ishp]] = -W10list[ishp][negatize[ishp]]
            nextraW = 4
            nextraK = nextraW + 3
            nextraT = nextraK + 3
            #Wstar_dict['as_list'] = [W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,s02,Kin0,Kin1,Kxout0,Kyout0,Kxout1,Kyout1,kappa,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1,XX,XXp,Eta,Xi,h1,h2,bl,amp] #,h2
            init_val = [1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,1]
            W10list = [iv*np.ones(shp) for iv,shp in zip(init_val,shapes1)]
            #W10list[nextraW+4] = np.ones(shapes1[nextraW+4]) # s02
            #W10list[nextraW+5] = np.ones(shapes1[nextraW+5]) # K
            #W10list[nextraW+6] = np.ones(shapes1[nextraW+6]) # K
            #W10list[nextraW+7] = np.zeros(shapes1[nextraW+7]) # K
            #W10list[nextraW+8] = np.zeros(shapes1[nextraW+8]) # K
            #W10list[nextraK+6] = np.ones(shapes1[nextraK+6]) # kappa
            #W10list[nextraK+7] = np.ones(shapes1[nextraK+7]) # T
            #W10list[nextraK+8] = np.ones(shapes1[nextraK+8]) # T
            #W10list[nextraK+9] = np.zeros(shapes1[nextraK+9]) # T
            #W10list[nextraK+10] = np.zeros(shapes1[nextraK+10]) # T
            W20list[0] = XXhat #np.concatenate(Xhat,axis=1) #XX
            W20list[1] = get_pc_dim(Xpc_list,nN=nN,nPQ=nP,nS=nS,nT=nT,idim=0,foldT=foldT) #XXp
            W20list[2] = Eta0 #np.zeros(shapes[nextraT+10]) #Eta
            W20list[3] = Xi0 #Xi
            #print(XXhat.shape)
            isn_init = np.array(((3,5),(-5,-5)))
            nvar,nxy = 4,2
            freeze_vals = [[None for _ in range(nxy)] for _ in range(nvar)]
            for ivar in range(nvar):
                for ixy in range(nxy):
                    iflat = np.ravel_multi_index((ivar,ixy),(nvar,nxy))
                    freeze_vals[ivar][ixy] = np.zeros(bd1list[iflat].shape)
                    freeze_vals[ivar][ixy][bd1list[iflat]==0] = np.nan
            if init_W_from_lsq:
                # shapes1: 0.W0x,1.W0y,2.W1x,3.W1y,4.W2x,5.W2y,6.W3x,7.W3y,8.s02,9.Kin0,10.Kin1,11.Kout0,12.Kout1,13.kappa,14.Tin0,15.Tin1,16.Txout0,Tyout0,17.Txout1,Tyout1,18.h1,19.h2,20.bl,21.amp
                # shapes2: 0.XX,1.XXp,2.Eta,3.Xi
                #W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,Kin0,Kin1,Tin0,Tin1 = initialize_Ws(Xhat,Yhat,Xpc_list,Ypc_list,scale_by=1)
                lams = 1e5*np.array((0,1,1,1,0,1,0,1))
                if constrain_isn:
                    freeze_vals[0][1][slice(0,None,3)][:,slice(0,None,3)] = isn_init
                #Wlist = [W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,Kin0,Kin1,Kxout0,Kyout0,Kxout1,Kyout1,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1]
                # W1list = [W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,s02,Kin0,Kin1,Kxout0,Kyout0,Kxout1,Kyout1,kappa,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1,h1,h2,bl,amp] #,h2
                # W0,W1,W2,W3,Kin0,Kin1,Tin0,Tin1
                thisWlist = initialize_Ws(Xhat,Yhat,Xpc_list,Ypc_list,scale_by=1,freeze_vals=freeze_vals,lams=lams,foldT=foldT)
                #Winds = [0,1,2,3,4,5,6,7,9,10,11,12,13,14,16,17,18,19,20,21]
                Winds = [0,1,2,3,4,5,6,7,11,12,13,14,15,16,18,19,20,21,22,23]
                for ivar,Wind in enumerate(Winds):
                    if shapes1[Wind] == thisWlist[ivar].shape:
                        W10list[Wind] = thisWlist[ivar]
                #W10list[0],W10list[1] = initialize_W(Xhat,Yhat,scale_by=scale_init_by)
                for Wind in Winds+[8,9]:
                    W10list[Wind] = W10list[Wind] + init_noise*np.random.randn(*W10list[Wind].shape)
            else:
                if constrain_isn:
                    W10list[1][slice(0,None,3)][:,slice(0,None,3)] = isn_init
                    #W10list[1][0,0] = 3
                    #W10list[1][0,3] = 5
                    #W10list[1][3,0] = -5
                    #W10list[1][3,3] = -5
            if init_W_from_file: # did not adjust this yet
                #Wstar_dict['as_list'] = [W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,W0xrun,W0yrun,s02,Kin0,Kin1,Kxout0,Kyout0,Kxout1,Kyout1,kappa,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1,XX,XXp,Eta,Xi,h1,h2,bl,amp] #,h2
                npyfile = np.load(init_file,allow_pickle=True)[()]
                print(len(npyfile['as_list']))
                print([w.shape for w in npyfile['as_list']])
                W10list = [npyfile['as_list'][ivar] for ivar in list(np.arange(24))+[28,29,30,31]]
                W20list = [npyfile['as_list'][ivar] for ivar in [24,25,26,27]]
                if correct_Eta:
                    #assert(True==False)
                    W20list[2] = Eta0.copy()
                #if len(W10list) < len(shapes1):
                #    #assert(True==False)
                #    W10list = W10list + [np.array(1),np.zeros((nQ,)),np.zeros((nT*nS*nQ,))] # add bl, amp #np.array(1), #h2,
                #W10 = unparse_W(W10list)
                #W20 = unparse_W(W20list)
                opt = fmc.gen_opt_axon(nN=nN,nP=nP,nQ=nQ,nS=nS,nT=nT,foldT=foldT,nondim=nondim,run_modulation=run_modulation)
                #opt = fmc.gen_opt()
                #resEta0,resXi0 = fmc.compute_res(W10,W20,opt)
                #if init_W1xy_with_res:
                #    W1x0,W1y0,Kin10,Tin10 = optimize_W1xy(W10list,W20list,opt)
                #    W0list[2] = W1x0
                #    W0list[3] = W1y0
                #    W0list[10] = Kin10
                #    W0list[15] = Tin10
                #if init_W2xy_with_res:
                #    W2x0,W2y0 = optimize_W2xy(W10list,W20list,opt)
                #    W0list[4] = W2x0
                #    W0list[5] = W2y0
                if init_Eta_with_s02:
                    #assert(True==False)
                    s02 = W10list[4].copy()
                    Eta0 = invert_f_mt_with_s02(YYhat,s02,nS=nS,nT=nT)
                    W20list[2] = Eta0.copy()
                for ivar in range(len(W10list)): #[0,1,4,5]: # Wmx, Wmy, s02, k
                    #print(init_noise)
                    W10list[ivar] = W10list[ivar] + init_noise*np.random.randn(*W10list[ivar].shape)
                #W0list = npyfile['as_list']

            np.save('/home/dan/calnet_data/W0list.npy',{'W10list':W10list,'W20list':W20list,'bd1list':bd1list,'bd2list':bd2list,'freeze_vals':freeze_vals,'bounds1':bounds1,'bounds2':bounds2},allow_pickle=True)
            #extra_Ws = [np.zeros_like(W10list[ivar]) for ivar in range(2)]
            #extra_ks = [np.zeros_like(W10list[5]) for ivar in range(3)]
            #extra_Ts = [np.zeros_like(W10list[7]) for ivar in range(3)]
            #W10list = W10list[:4] + extra_Ws*2 + W10list[4:6] + extra_ks + W10list[6:8] + extra_Ts + W10list[8:]
            #print(len(W10list))
            W1t[ihyper][itry],W2t[ihyper][itry],loss[ihyper][itry],gr,hess,result = calnet.fitting_2step_spatial_feature_opto_multiout_axon_nonlinear_run_modulation.fit_W_sim(
                Xhat,Xpc_list,Yhat,Ypc_list,
                pop_rate_fn=sim_utils.f_miller_troyer,
                pop_deriv_fn=sim_utils.fprime_miller_troyer,
                neuron_rate_fn=sim_utils.evaluate_f_mt,
                W10list=W10list.copy(),W20list=W20list.copy(),
                bounds1=bounds1,bounds2=bounds2,niter=niter,wt_dict=wt_dict,
                l2_penalty=l2_penalty,l1_penalty=l1_penalty,compute_hessian=False,
                dt=dt,perturbation_size=perturbation_size,dYY=dYY,
                constrain_isn=constrain_isn,tv=tv,foldT=foldT,
                use_opto_transforms=use_opto_transforms,
                opto_transform1=opto_transform1,opto_transform2=opto_transform2,
                nondim=nondim,run_modulation=run_modulation)

    #def parse_W(W):
    #    W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,s02,Kin0,Kin1,Kout0,Kout1,kappa,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1,XX,XXp,Eta,Xi,h = W
    #    return W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,s02,Kin0,Kin1,Kout0,Kout1,kappa,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1,XX,XXp,Eta,Xi,h

    def parse_W1(W):
        W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,W0xrun,W0yrun,s02,Kin0,Kin1,Kxout0,Kyout0,Kxout1,Kyout1,kappa,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1,h1,h2,bl,amp = W #h2,
        return W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,W0xrun,W0yrun,s02,Kin0,Kin1,Kxout0,Kyout0,Kxout1,Kyout1,kappa,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1,h1,h2,bl,amp #h2,

    def parse_W2(W):
        XX,XXp,Eta,Xi = W
        return XX,XXp,Eta,Xi

    def unparse_W(Ws):
        return np.concatenate([ww.flatten() for ww in Ws])

    itry = 0
    W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,W0xrun,W0yrun,s02,Kin0,Kin1,Kxout0,Kyout0,Kxout1,Kyout1,kappa,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1,h1,h2,bl,amp = parse_W1(W1t[0][0]) #h2,
    XX,XXp,Eta,Xi = parse_W2(W2t[0][0])
    #labels = ['W0x','W0y','W1x','W1y','W2x','W2y','W3x','W3y','s02','Kin0','Kin1','Kout0','Kout1','kappa','Tin0','Tin1','Tout0','Tout1','XX','XXp','Eta','Xi','h']
    labels1 = ['W0x','W0y','W1x','W1y','W2x','W2y','W3x','W3y','W0xrun','W0yrun','s02','Kin0','Kin1','Kxout0','Kyout0','Kxout1','Kyout1','kappa','Tin0','Tin1','Txout0','Tyout0','Txout1','Tyout1','h1','h2','bl','amp'] #,'h2'
    labels2 = ['XX','XXp','Eta','Xi']
    Wstar_dict = {}
    for i,label in enumerate(labels1):
        Wstar_dict[label] = W1t[0][0][i]
    for i,label in enumerate(labels2):
        Wstar_dict[label] = W2t[0][0][i]
    #Wstar_dict = {}
    #for i,label in enumerate(labels):
    #    Wstar_dict[label] = W1t[0][0][i]
    Wstar_dict['as_list'] = [W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,W0xrun,W0yrun,s02,Kin0,Kin1,Kxout0,Kyout0,Kxout1,Kyout1,kappa,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1,XX,XXp,Eta,Xi,h1,h2,bl,amp] #,h2
    Wstar_dict['loss'] = loss[0][0]
    Wstar_dict['wt_dict'] = wt_dict
    np.save(weights_file,Wstar_dict,allow_pickle=True)
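# A minimal sketch of reading the result back (the same pattern the function
# itself uses for init_file): np.save on a dict pickles it, so np.load needs
# allow_pickle=True plus the [()] item lookup to recover the dict.
Wstar_dict = np.load(weights_file, allow_pickle=True)[()]
W0x, W0y = Wstar_dict['W0x'], Wstar_dict['W0y']
print(Wstar_dict['loss'])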
# parameters to learn
param_vec = np.array([-0.2, -0.2, -0.2, -0.2, 1.5, 1.5])
labels = ['N1', 'N2', 'C1', 'C2', 'C0']


def grad_wrapper(param_vec, i):
    ''' for the autograd optimisers '''
    return grad_func(param_vec, N, Cin, next_N, C, C_0)


'''
xSol, Cins = create_time_series('/Users/Neythen/masters_project/app/CBcurl_master/examples/parameter_files/unstable_2_species.yaml',
                                '/Users/Neythen/masters_project/results/lookup_table_results/LT_unstable_repeats/repeat1/Q_table.npy',
                                10000)
np.save('unstable.npy', xSol)
np.save('unstable_Cins.npy', Cins)
'''

xSol = np.load('/Users/Neythen/Desktop/masters_project/parameter_estimation/system_trajectories/double_aux.npy')
Cins = np.load('/Users/Neythen/Desktop/masters_project/parameter_estimation/system_trajectories/double_aux_Cins.npy')

fullSol = xSol
losses = []
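# A usage sketch (assumption: the "autograd optimisers" named in the
# docstring are autograd's built-ins; in recent autograd versions they live
# in autograd.misc.optimizers). grad_wrapper already has the
# (params, iteration) signature those optimisers expect.
from autograd.misc.optimizers import adam

param_est = adam(grad_wrapper, param_vec, step_size=0.01, num_iters=200)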
            feat_list.append(d['features'])
            idx += 1
        # Now loop over each node again and figure out its neighbors.
        for n, d in p.nodes(data=True):
            graph_idxs[project['title']].append(d['idx'])
            nodes_nbrs[d['idx']].append(d['idx'])
            graph_nodes[project['title']][d['idx']] = n
            for nbr in p.neighbors(n):
                nodes_nbrs[d['idx']].append(p.node[nbr]['idx'])
                # print(nodes_nbrs[d['idx']])
    except:
        print('Did not make graph for {0}'.format(project['code']))

# Save the data to disk:
# The array...
feat_array = np.vstack(feat_list)
np.save('../data/feat_array.npy', feat_array)

# The node idxs and their neighbor idxs...
with open('../data/nodes_nbrs.pkl', 'wb') as f:
    pkl.dump(nodes_nbrs, f)

# The graphs' seqids and their node idxs...
with open('../data/graph_idxs.pkl', 'wb') as f:
    pkl.dump(graph_idxs, f)

# The graphs': {'SeqID1':{1:'A51SER',...},...}
with open('../data/graph_nodes.pkl', 'wb') as f:
    pkl.dump(graph_nodes, f)
def fit_weights_and_save(
        weights_file,
        ca_data_file='rs_vm_denoise_200605.npy',
        opto_silencing_data_file='vip_halo_data_for_sim.npy',
        opto_activation_data_file='vip_chrimson_data_for_sim.npy',
        constrain_wts=None,
        allow_var=True,
        fit_s02=True,
        constrain_isn=True,
        l2_penalty=0.01,
        init_noise=0.1,
        init_W_from_lsq=False,
        scale_init_by=1,
        init_W_from_file=False,
        init_file=None):

    nsize, ncontrast = 6, 6

    # In[3]:
    npfile = np.load(ca_data_file, allow_pickle=True)[()]  #,{'rs':rs},allow_pickle=True)
    rs = npfile['rs']

    # In[4]:
    nsize, ncontrast, ndir = 6, 6, 8
    ori_dirs = [[0, 4], [2, 6]]  #[[0,4],[1,3,5,7],[2,6]]
    nT = len(ori_dirs)
    nS = len(rs[0])

    def sum_to_1(r):
        R = r.reshape((r.shape[0], -1))
        #R = R/np.nansum(R[:,~np.isnan(R.sum(0))],axis=1)[:,np.newaxis]
        R = R / np.nansum(R, axis=1)[:, np.newaxis]  # changed 8/28
        return R

    def norm_to_mean(r):
        R = r.reshape((r.shape[0], -1))
        R = R / np.nanmean(R[:, ~np.isnan(R.sum(0))], axis=1)[:, np.newaxis]
        return R

    Rs = [[None, None] for i in range(len(rs))]
    Rso = [[[None for iT in range(nT)] for iS in range(nS)]
           for icelltype in range(len(rs))]
    rso = [[[None for iT in range(nT)] for iS in range(nS)]
           for icelltype in range(len(rs))]

    for iR, r in enumerate(rs):  #rs_denoise):
        print(iR)
        for ialign in range(nS):
            Rs[iR][ialign] = sum_to_1(r[ialign][:, :nsize, :])
            # Rs[iR][ialign] = von_mises_denoise(Rs[iR][ialign].reshape((-1,nsize,ncontrast,ndir)))

    kernel = np.ones((1, 2, 2))
    kernel = kernel / kernel.sum()

    for iR, r in enumerate(rs):
        for ialign in range(nS):
            for iori in range(nT):
                Rso[iR][ialign][iori] = np.nanmean(
                    Rs[iR][ialign].reshape(
                        (-1, nsize, ncontrast, ndir))[:, :, :, ori_dirs[iori]], -1)
                Rso[iR][ialign][iori][:, :, 0] = np.nanmean(
                    Rso[iR][ialign][iori][:, :, 0], 1)[:, np.newaxis]
                Rso[iR][ialign][iori][:, 1:, 1:] = ssi.convolve(
                    Rso[iR][ialign][iori], kernel, 'valid')
                Rso[iR][ialign][iori] = Rso[iR][ialign][iori].reshape(
                    Rso[iR][ialign][iori].shape[0], -1)

    #kernel = np.ones((1,2,2))
    #kernel = kernel/kernel.sum()
    #
    #for iR,r in enumerate(rs):
    #    for ialign in range(nS):
    #        for iori in range(nT):
    #            Rso[iR][ialign][iori] = np.nanmean(Rs[iR][ialign].reshape((-1,nsize,ncontrast,ndir))[:,:,:,ori_dirs[iori]],-1)
    #            Rso[iR][ialign][iori] = ssi.convolve(Rso[iR][ialign][iori],kernel,'same')
    #            Rso[iR][ialign][iori] = Rso[iR][ialign][iori].reshape(Rso[iR][ialign][iori].shape[0],-1)

    # In[6]:
    def set_bound(bd, code, val=0):
        # set bounds to 0 where 0s occur in 'code'
        for iitem in range(len(bd)):
            bd[iitem][code[iitem]] = val

    # In[7]:
    nN = 36
    nS = 2
    nP = 2
    nT = 2
    nQ = 4

    # code for bounds: 0   , constrained to 0
    #                  +/-1, constrained to +/-1
    #                  1.5 , constrained to [0,1]
    #                  2   , constrained to [0,inf)
    #                  -2  , constrained to (-inf,0]
    #                  3   , unconstrained
    Wmx_bounds = 3 * np.ones((nP, nQ), dtype=int)
    Wmx_bounds[0, 1] = 0  # SSTs don't receive L4 input

    if allow_var:
        Wsx_bounds = 3 * np.ones(Wmx_bounds.shape)  #Wmx_bounds.copy()*0 #np.zeros_like(Wmx_bounds)
        Wsx_bounds[0, 1] = 0
    else:
        Wsx_bounds = np.zeros(Wmx_bounds.shape)  #Wmx_bounds.copy()*0 #np.zeros_like(Wmx_bounds)

    Wmy_bounds = 3 * np.ones((nQ, nQ), dtype=int)
    Wmy_bounds[0, :] = 2  # PCs are excitatory
    Wmy_bounds[1:, :] = -2  # all the cell types except PCs are inhibitory
    Wmy_bounds[1, 1] = 0  # SSTs don't inhibit themselves
    # Wmy_bounds[3,1] = 0 # PVs are allowed to inhibit SSTs, consistent with Hillel's unpublished results, but not consistent with Pfeffer et al.
    Wmy_bounds[2, 0] = 0  # VIPs don't inhibit L2/3 PCs. According to Pfeffer et al., only L5 PCs were found to get VIP inhibition

    if allow_var:
        Wsy_bounds = 3 * np.ones(Wmy_bounds.shape)  #Wmy_bounds.copy()*0 #np.zeros_like(Wmy_bounds)
        Wsy_bounds[1, 1] = 0
        Wsy_bounds[3, 1] = 0
        Wsy_bounds[2, 0] = 0
    else:
        Wsy_bounds = np.zeros(Wmy_bounds.shape)  #Wmy_bounds.copy()*0 #np.zeros_like(Wmy_bounds)

    if not constrain_wts is None:
        for wt in constrain_wts:
            Wmy_bounds[wt[0], wt[1]] = 0
            Wsy_bounds[wt[0], wt[1]] = 0

    def tile_nS_nT_nN(kernel):
        row = np.concatenate([kernel for idim in range(nS * nT)], axis=0)[np.newaxis, :]
        tiled = np.concatenate([row for irow in range(nN)], axis=0)
        return tiled

    if fit_s02:
        s02_bounds = 2 * np.ones((nQ, ))  # permitting noise as a free parameter
    else:
        s02_bounds = np.ones((nQ, ))

    k_bounds = 1.5 * np.ones((nQ, ))
    kappa_bounds = np.ones((1, ))
    # kappa_bounds = 2*np.ones((1,))
    T_bounds = 1.5 * np.ones((nQ, ))

    X_bounds = tile_nS_nT_nN(np.array([2, 1]))
    # X_bounds = np.array([np.array([2,1,2,1])]*nN)
    Xp_bounds = tile_nS_nT_nN(np.array([3, 1]))
    # Xp_bounds = np.array([np.array([3,1,3,1])]*nN)
    # Y_bounds = tile_nS_nT_nN(2*np.ones((nQ,)))
    # # Y_bounds = 2*np.ones((nN,nT*nS*nQ))
    Eta_bounds = tile_nS_nT_nN(3 * np.ones((nQ, )))
    # Eta_bounds = 3*np.ones((nN,nT*nS*nQ))
    if allow_var:
        Xi_bounds = tile_nS_nT_nN(3 * np.ones((nQ, )))
    else:
        Xi_bounds = tile_nS_nT_nN(np.zeros((nQ, )))
    # Xi_bounds = 3*np.ones((nN,nT*nS*nQ))
    h1_bounds = -2 * np.ones((1, ))
    h2_bounds = 2 * np.ones((1, ))

    # In[8]:
    # shapes = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ,),(1,),(nN,nS*nP),(nN,nS*nQ),(nN,nS*nQ),(nN,nS*nQ)]
    shapes = [(nP, nQ), (nQ, nQ), (nP, nQ), (nQ, nQ), (nQ, ), (nQ, ), (1, ),
              (nQ, ), (nN, nT * nS * nP), (nN, nT * nS * nP),
              (nN, nT * nS * nQ), (nN, nT * nS * nQ), (1, ), (1, )]
    # Wmx, Wmy, Wsx, Wsy, s02, k, kappa, T, XX, XXp, Eta, Xi

    lb = [-np.inf * np.ones(shp) for shp in shapes]
    ub = [np.inf * np.ones(shp) for shp in shapes]
    bdlist = [
        Wmx_bounds, Wmy_bounds, Wsx_bounds, Wsy_bounds, s02_bounds, k_bounds,
        kappa_bounds, T_bounds, X_bounds, Xp_bounds, Eta_bounds, Xi_bounds,
        h1_bounds, h2_bounds
    ]

    set_bound(lb, [bd == 0 for bd in bdlist], val=0)
    set_bound(ub, [bd == 0 for bd in bdlist], val=0)
    set_bound(lb, [bd == 2 for bd in bdlist], val=0)
    set_bound(ub, [bd == -2 for bd in bdlist], val=0)
    set_bound(lb, [bd == 1 for bd in bdlist], val=1)
    set_bound(ub, [bd == 1 for bd in bdlist], val=1)
    set_bound(lb, [bd == 1.5 for bd in bdlist], val=0)
    set_bound(ub, [bd == 1.5 for bd in bdlist], val=1)
    set_bound(lb, [bd == -1 for bd in bdlist], val=-1)
    set_bound(ub, [bd == -1 for bd in bdlist], val=-1)

    # for bd in [lb,ub]:
    #     for ind in [2,3]:
    #         bd[ind][:,1] = 0
    # temporary for no variation expt.
    # lb[2] = np.zeros_like(lb[2])
    # lb[3] = np.zeros_like(lb[3])
    # lb[4] = np.ones_like(lb[4])
    # lb[5] = np.zeros_like(lb[5])
    # ub[2] = np.zeros_like(ub[2])
    # ub[3] = np.zeros_like(ub[3])
    # ub[4] = np.ones_like(ub[4])
    # ub[5] = np.ones_like(ub[5])
    # temporary for no variation expt.
    lb = np.concatenate([a.flatten() for a in lb])
    ub = np.concatenate([b.flatten() for b in ub])
    bounds = [(a, b) for a, b in zip(lb, ub)]

    # In[10]:
    nS = 2
    ndims = 5
    ncelltypes = 5
    Yhat = [[None for iT in range(nT)] for iS in range(nS)]
    Xhat = [[None for iT in range(nT)] for iS in range(nS)]
    Ypc_list = [[None for iT in range(nT)] for iS in range(nS)]
    Xpc_list = [[None for iT in range(nT)] for iS in range(nS)]
    for iS in range(nS):
        mx = np.zeros((ncelltypes, ))
        yy = [None for icelltype in range(ncelltypes)]
        for icelltype in range(ncelltypes):
            yy[icelltype] = np.nanmean(Rso[icelltype][iS][0], 0)
            mx[icelltype] = np.nanmax(yy[icelltype])
        for iT in range(nT):
            y = [
                np.nanmean(Rso[icelltype][iS][iT], axis=0)[:, np.newaxis] / mx[icelltype]
                for icelltype in range(1, ncelltypes)
            ]
            Ypc_list[iS][iT] = [None for icelltype in range(1, ncelltypes)]
            for icelltype in range(1, ncelltypes):
                rss = Rso[icelltype][iS][iT].copy()  #.reshape(Rs[icelltype][ialign].shape[0],-1)
                rss = rss[np.isnan(rss).sum(1) == 0]
                # print(rss.max())
                # rss[rss<0] = 0
                # rss = rss[np.random.randn(rss.shape[0])>0]
                try:
                    u, s, v = np.linalg.svd(rss - np.mean(rss, 0)[np.newaxis])
                    Ypc_list[iS][iT][icelltype - 1] = [(s[idim], v[idim])
                                                       for idim in range(ndims)]
                    # print('yep on Y')
                    # print(np.min(np.sum(rs[icelltype][iS][iT],axis=1)))
                except:
                    # print('nope on Y')
                    print(np.mean(np.isnan(rss)))
                    print(np.min(np.sum(rs[icelltype][iS][iT], axis=1)))
            Yhat[iS][iT] = np.concatenate(y, axis=1)
            # x = sim_utils.columnize(Rso[0][iS][iT])[:,np.newaxis]
            icelltype = 0
            x = np.nanmean(Rso[icelltype][iS][iT], 0)[:, np.newaxis] / mx[icelltype]
            # opto_column = np.concatenate((np.zeros((nN,)),np.zeros((nNO/2,)),np.ones((nNO/2,))),axis=0)[:,np.newaxis]
            Xhat[iS][iT] = np.concatenate((x, np.ones_like(x)), axis=1)
            # Xhat[iS][iT] = np.concatenate((x,np.ones_like(x),opto_column),axis=1)
            icelltype = 0
            rss = Rso[icelltype][iS][iT].copy()
            rss = rss[np.isnan(rss).sum(1) == 0]
            # try:
            u, s, v = np.linalg.svd(rss - rss.mean(0)[np.newaxis])
            Xpc_list[iS][iT] = [None for iinput in range(2)]
            Xpc_list[iS][iT][0] = [(s[idim], v[idim]) for idim in range(ndims)]
            Xpc_list[iS][iT][1] = [(0, np.zeros((Xhat[0][0].shape[0], )))
                                   for idim in range(ndims)]
            # except:
            #     print('nope on X')
            #     print(np.mean(np.isnan(rss)))
            #     print(np.min(np.sum(Rso[icelltype][iS][iT],axis=1)))

    nN, nP = Xhat[0][0].shape
    nQ = Yhat[0][0].shape[1]

    # In[11]:
    def compute_f_(Eta, Xi, s02):
        return sim_utils.f_miller_troyer(
            Eta, Xi**2 + np.concatenate([s02 for ipixel in range(nS * nT)]))

    def compute_fprime_m_(Eta, Xi, s02):
        return sim_utils.fprime_miller_troyer(
            Eta, Xi**2 + np.concatenate([s02 for ipixel in range(nS * nT)])) * Xi

    def compute_fprime_s_(Eta, Xi, s02):
        s2 = Xi**2 + np.concatenate((s02, s02), axis=0)
        return sim_utils.fprime_s_miller_troyer(Eta, s2) * (Xi / s2)

    def sorted_r_eigs(w):
        drW, prW = np.linalg.eig(w)
        srtinds = np.argsort(drW)
        return drW[srtinds], prW[:, srtinds]

    # In[12]:
    # 0.Wmx, 1.Wmy, 2.Wsx, 3.Wsy, 4.s02, 5.K, 6.kappa, 7.T, 8.XX, 9.XXp, 10.Eta, 11.Xi
    shapes = [(nP, nQ), (nQ, nQ), (nP, nQ), (nQ, nQ), (nQ, ), (nQ, ), (1, ),
              (nQ, ), (nN, nT * nS * nP), (nN, nT * nS * nP),
              (nN, nT * nS * nQ), (nN, nT * nS * nQ), (1, ), (1, )]

    # In[13]:
    import calnet.fitting_spatial_feature
    import sim_utils

    # In[14]:
    opto_dict = np.load(opto_silencing_data_file, allow_pickle=True)[()]
    Yhat_opto = opto_dict['Yhat_opto']
    Yhat_opto = Yhat_opto / Yhat_opto[0::2].max(0)[np.newaxis, :]
    print(Yhat_opto.shape)
    h_opto = opto_dict['h_opto']
    dYY1 = Yhat_opto[1::2] - Yhat_opto[0::2]
    for to_overwrite in [1, 2, 5, 6]:
        dYY1[:, to_overwrite] = dYY1[:, to_overwrite + 8]
    for to_overwrite in [11, 15]:
        dYY1[:, to_overwrite] = dYY1[:, to_overwrite - 8]

    opto_dict = np.load(opto_activation_data_file, allow_pickle=True)[()]
    Yhat_opto = opto_dict['Yhat_opto']
    Yhat_opto = Yhat_opto / Yhat_opto[0::2].max(0)[np.newaxis, :]
    print(Yhat_opto.shape)
    h_opto = opto_dict['h_opto']
    dYY2 = Yhat_opto[1::2] - Yhat_opto[0::2]

    print('dYY1 mean: %03f' % np.nanmean(np.abs(dYY1)))
    print('dYY2 mean: %03f' % np.nanmean(np.abs(dYY2)))

    dYY = np.concatenate((dYY1, dYY2), axis=0)
    opto_mask = ~np.isnan(dYY)
    dYY[~opto_mask] = 0

    # In[ ]:
    from importlib import reload
    reload(calnet)
    #reload(calnet.fitting_spatial_feature_opto_bidi)
    reload(sim_utils)
    # reload(calnet.fitting_spatial_feature)

    # W0list = [np.ones(shp) for shp in shapes]
    wt_dict = {}
    wt_dict['X'] = 1
    wt_dict['Y'] = 3
    wt_dict['Eta'] = 1  # 10
    wt_dict['Xi'] = 0.1
    wt_dict['stims'] = np.ones((nN, 1))  #(np.arange(30)/30)[:,np.newaxis]**1
    # wt_dict['barrier'] = 0. #30.0 #0.1
    wt_dict['opto'] = 1e-1  #1e1
    wt_dict['isn'] = 0.1

    YYhat = calnet.fitting_spatial_feature_opto_bidi.flatten_nested_list_of_2d_arrays(Yhat)
    XXhat = calnet.fitting_spatial_feature_opto_bidi.flatten_nested_list_of_2d_arrays(Xhat)
    Eta0 = invert_f_mt(YYhat)

    ntries = 1
    nhyper = 1
    dt = 1e-1
    niter = int(np.round(50 / dt))  #int(1e4)
    perturbation_size = 5e-2
    # learning_rate = 1e-4 # 1e-5 #np.linspace(3e-4,1e-3,niter+1) # 1e-5
    #l2_penalty = 0.1
    Wt = [[None for itry in range(ntries)] for ihyper in range(nhyper)]
    loss = np.zeros((nhyper, ntries))
    is_neg = np.array([b[1] for b in bounds]) == 0
    counter = 0
    negatize = [np.zeros(shp, dtype='bool') for shp in shapes]
    for ishp, shp in enumerate(shapes):
        nel = np.prod(shp)
        negatize[ishp][:][is_neg[counter:counter + nel].reshape(shp)] = True
        counter = counter + nel

    for ihyper in range(nhyper):
        for itry in range(ntries):
            print((ihyper, itry))
            W0list = [
                init_noise * (ihyper + 1) * np.random.rand(*shp) for shp in shapes
            ]
            counter = 0
            for ishp, shp in enumerate(shapes):
                W0list[ishp][negatize[ishp]] = -W0list[ishp][negatize[ishp]]
            W0list[4] = np.ones(shapes[5])  # s02
            W0list[5] = np.ones(shapes[5])  # K
            W0list[6] = np.ones(shapes[6])  # kappa
            W0list[7] = np.ones(shapes[7])  # T
            W0list[8] = np.concatenate(Xhat, axis=1)  #XX
            W0list[9] = np.zeros_like(W0list[8])  #XXp
            W0list[10] = Eta0  #np.zeros(shapes[10]) #Eta
            W0list[11] = np.zeros(shapes[11])  #Xi
            #[Wmx,Wmy,Wsx,Wsy,s02,k,kappa,T,XX,XXp,Eta,Xi]
            # W0list = Wstar_dict['as_list'].copy()
            # W0list[1][1,0] = -1.5
            # W0list[1][3,0] = -1.5
            if init_W_from_lsq:
                W0list[0], W0list[1] = initialize_W(Xhat, Yhat, scale_by=scale_init_by)
                for ivar in range(0, 2):
                    W0list[ivar] = W0list[ivar] + init_noise * np.random.randn(*W0list[ivar].shape)
            if constrain_isn:
                W0list[1][0, 0] = 3
                W0list[1][0, 3] = 5
                W0list[1][3, 0] = -5
                W0list[1][3, 3] = -5
            if init_W_from_file:
                npyfile = np.load(init_file, allow_pickle=True)[()]
                W0list = npyfile['as_list']
                if len(W0list) < len(shapes):
                    W0list = W0list + [np.array(0.7)]  # add h2
                n = 0.25
                W0list[7][0] = 1 / (n + 1) * (W0list[7][0] + n * 0)  # T
                W0list[7][3] = 1 / (n + 1) * (W0list[7][3] + n * 1)  # T
                # wt_dict['Xi'] = 10
                # wt_dict['Eta'] = 10
            Wt[ihyper][itry], loss[ihyper][itry], gr, hess, result = calnet.fitting_spatial_feature_opto_bidi.fit_W_sim(
                Xhat, Xpc_list, Yhat, Ypc_list,
                pop_rate_fn=sim_utils.f_miller_troyer,
                pop_deriv_fn=sim_utils.fprime_miller_troyer,
                neuron_rate_fn=sim_utils.evaluate_f_mt,
                W0list=W0list.copy(), bounds=bounds, niter=niter,
                wt_dict=wt_dict, l2_penalty=l2_penalty,
                compute_hessian=False, dt=dt,
                perturbation_size=perturbation_size, dYY=dYY,
                constrain_isn=constrain_isn, opto_mask=opto_mask)
            # Wt[ihyper][itry] = [w[-1] for w in Wt_temp]
            # loss[ihyper,itry] = loss_temp[-1]

    # In[285]:
    def parse_W(W):
        Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, XX, XXp, Eta, Xi, h1, h2 = W
        return Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, XX, XXp, Eta, Xi, h1, h2

    itry = 0
    Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, XX, XXp, Eta, Xi, h1, h2 = parse_W(Wt[0][0])

    # In[286]:
    labels = [
        'Wmx', 'Wmy', 'Wsx', 'Wsy', 's02', 'K', 'kappa', 'T', 'XX', 'XXp',
        'Eta', 'Xi', 'h1', 'h2'
    ]
    Wstar_dict = {}
    for i, label in enumerate(labels):
        Wstar_dict[label] = Wt[0][0][i]
    Wstar_dict['as_list'] = [
        Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, XX, XXp, Eta, Xi, h1, h2
    ]
    Wstar_dict['loss'] = loss[0][0]
    Wstar_dict['wt_dict'] = wt_dict
    np.save(weights_file, Wstar_dict, allow_pickle=True)
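# A hypothetical invocation (sketch only; the output filename is made up,
# and the data files fall back to the defaults in the signature above).
if __name__ == '__main__':
    fit_weights_and_save('Wstar_opto_bidi.npy', init_W_from_lsq=True)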
def load_cached_train_matrix(train_spec_files, train_idx, split_type, force_no_cache=False):
    # check if cached file exists and is legit
    CACHE_TRAIN_FILE = cache_file_name(split_type, len(train_idx))
    print "cache filename = %s" % CACHE_TRAIN_FILE
    if os.path.exists(CACHE_TRAIN_FILE) and not force_no_cache:
        handle = open(CACHE_TRAIN_FILE, 'rb')
        train_idx_disk = np.load(handle)
        # confirm input train_idx from script matches train_idx from disk
        if np.all(train_idx_disk == train_idx):
            print "Found matching cached qso_spec_data matrix on disk! (%s)" % CACHE_TRAIN_FILE
            # read the remaining arrays in exactly the order they were saved
            spec_grid = np.load(handle)
            spec_ivar_grid = np.load(handle)
            spec_mod_grid = np.load(handle)
            unique_lams = np.load(handle)
            spec_zs = np.load(handle)
            spec_ids = np.load(handle)
            return spec_grid, spec_ivar_grid, spec_mod_grid, unique_lams, spec_zs, spec_ids

    #### load the slow way :(
    print "cached training matrix is not the same!!! loading from spec files! (this will take a while)"
    spec_grid, spec_ivar_grid, spec_mod_grid, unique_lams, spec_zs, spec_ids, badids = \
        ru.load_specs_from_disk(train_spec_files)
    with open(CACHE_TRAIN_FILE, 'wb') as handle:
        np.save(handle, train_idx)
        np.save(handle, spec_grid)
        np.save(handle, spec_ivar_grid)
        np.save(handle, spec_mod_grid)
        np.save(handle, unique_lams)
        np.save(handle, spec_zs)
        np.save(handle, spec_ids)
    return spec_grid, spec_ivar_grid, spec_mod_grid, unique_lams, spec_zs, spec_ids
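# A minimal sketch (not from the original source) of the caching pattern
# above: successive np.save calls on one open file handle append arrays, and
# successive np.load calls on the reopened handle read them back in order.
with open('demo_cache.npy', 'wb') as fh:
    np.save(fh, np.arange(3))
    np.save(fh, np.eye(2))
with open('demo_cache.npy', 'rb') as fh:
    a = np.load(fh)  # array([0, 1, 2])
    b = np.load(fh)  # 2x2 identity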
    weights_1, weights_2, bias_1, bias_2 = theta_reshape(thetas)
    # Calculate forward pass - saving results for state and rhs to array
    pred_state_array = np.zeros(shape=(tsteps, state_len), dtype='double')
    pred_state_array[0, :] = true_state_array[0, :]
    temp_state = np.copy(true_state_array[0, :])
    for i in range(1, tsteps):
        time = np.reshape(time_array[i], (1, 1))
        output_state, output_rhs = euler_forward(temp_state, weights_1, weights_2, bias_1, bias_2, time)
        pred_state_array[i, :] = output_state[:]
        temp_state = np.copy(output_state)

    return pred_state_array


#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Training
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if deployment_mode == 'train':
    thetas_optimal, loss_list, val_loss_list = rms_prop_optimize(thetas)
    np.save('Trained_Weights.npy', thetas_optimal)
    np.save('Train_Loss.npy', np.asarray(loss_list))
    np.save('Val_Loss.npy', np.asarray(val_loss_list))
    visualize(mode=deployment_mode)
else:
    visualize(mode='test')

plt.show()
print("Loading training data...") train_images, train_labels, test_images, test_labels = loadMNIST(normalize = False) init_params = init_random_params(param_scale, layer_sizes) num_batches = int(np.ceil(len(train_images) / batch_size)) def batch_indices(iter): idx = iter % num_batches return slice(idx * batch_size, (idx+1) * batch_size) # Define training objective def objective(params, iter): idx = batch_indices(iter) return -log_posterior(params, train_images[idx], train_labels[idx], L2_reg) # Get gradient of objective using autograd. objective_grad = grad(objective) print(" Epoch | Train accuracy | Test accuracy ") def print_perf(params, iter, gradient): if iter % num_batches == 0: train_acc = accuracy(params, train_images, train_labels) test_acc = accuracy(params, test_images, test_labels) print("{:15}|{:20}|{:20}".format(iter//num_batches, train_acc, test_acc)) # The optimizers provided can optimize lists, tuples, or dicts of parameters. optimized_params = adam(objective_grad, init_params, step_size=step_size, num_iters=num_epochs * num_batches, callback=print_perf) np.save('optpara',optimized_params) print('done')
def fit_weights_and_save( weights_file, ca_data_file='rs_vm_denoise_200605.npy', opto_silencing_data_file='vip_halo_data_for_sim.npy', opto_activation_data_file='vip_chrimson_data_for_sim.npy', constrain_wts=None, allow_var=True, fit_s02=True, constrain_isn=True, tv=False, l2_penalty=0.01, init_noise=0.1, init_W_from_lsq=False, scale_init_by=1, init_W_from_file=False, init_file=None, correct_Eta=False, init_Eta_with_s02=False, init_Eta12_with_dYY=False, use_opto_transforms=False, share_residuals=False, stimwise=False, simulate1=True, simulate2=False, help_constrain_isn=True, ignore_halo_vip=False, verbose=True, free_amplitude=False, norm_opto_transforms=False, zero_extra_weights=None, no_halo_res=False, l23_as_l4=False): nsize, ncontrast = 6, 6 npfile = np.load(ca_data_file, allow_pickle=True)[( )] #,{'rs':rs,'rs_denoise':rs_denoise},allow_pickle=True) rs = npfile['rs'] #rs_denoise = npfile['rs_denoise'] nsize, ncontrast, ndir = 6, 6, 8 #ori_dirs = [[0,4],[2,6]] #[[0,4],[1,3,5,7],[2,6]] ori_dirs = [[0, 1, 2, 3, 4, 5, 6, 7]] nT = len(ori_dirs) nS = len(rs[0]) def sum_to_1(r): R = r.reshape((r.shape[0], -1)) #R = R/np.nansum(R[:,~np.isnan(R.sum(0))],axis=1)[:,np.newaxis] R = R / np.nansum(R, axis=1)[:, np.newaxis] # changed 8/28 return R def norm_to_mean(r): R = r.reshape((r.shape[0], -1)) R = R / np.nanmean(R[:, ~np.isnan(R.sum(0))], axis=1)[:, np.newaxis] return R Rs = [[None, None] for i in range(len(rs))] Rso = [[[None for iT in range(nT)] for iS in range(nS)] for icelltype in range(len(rs))] rso = [[[None for iT in range(nT)] for iS in range(nS)] for icelltype in range(len(rs))] for iR, r in enumerate(rs): #rs_denoise): #print(iR) for ialign in range(nS): #Rs[iR][ialign] = r[ialign][:,:nsize,:] #sm = np.nanmean(np.nansum(np.nansum(Rs[iR][ialign],1),1)) #Rs[iR][ialign] = Rs[iR][ialign]/sm #print('frac isnan Rs %d,%d: %f'%(iR,ialign,np.isnan(r[ialign]).mean())) Rs[iR][ialign] = sum_to_1(r[ialign][:, :nsize, :]) # Rs[iR][ialign] = von_mises_denoise(Rs[iR][ialign].reshape((-1,nsize,ncontrast,ndir))) kernel = np.ones((1, 2, 2)) kernel = kernel / kernel.sum() for iR, r in enumerate(rs): for ialign in range(nS): for iori in range(nT): #print('this Rs shape: '+str(Rs[iR][ialign].shape)) #print('this Rs reshaped shape: '+str(Rs[iR][ialign].reshape((-1,nsize,ncontrast,ndir))[:,:,:,ori_dirs[iori]].shape)) #print('this Rs max percent nan: '+str(np.isnan(Rs[iR][ialign].reshape((-1,nsize,ncontrast,ndir))[:,:,:,ori_dirs[iori]]).mean(-1).max())) Rso[iR][ialign][iori] = np.nanmean( Rs[iR][ialign].reshape( (-1, nsize, ncontrast, ndir))[:, :, :, ori_dirs[iori]], -1) Rso[iR][ialign][iori][:, :, 0] = np.nanmean( Rso[iR][ialign][iori][:, :, 0], 1)[:, np.newaxis] # average 0 contrast values #print('frac isnan pre-conv Rso %d,%d,%d: %f'%(iR,ialign,iori,np.isnan(Rso[iR][ialign][iori]).mean())) Rso[iR][ialign][iori][:, 1:, 1:] = ssi.convolve( Rso[iR][ialign][iori], kernel, 'valid') Rso[iR][ialign][iori] = Rso[iR][ialign][iori].reshape( Rso[iR][ialign][iori].shape[0], -1) #print('frac isnan Rso %d,%d,%d: %f'%(iR,ialign,iori,np.isnan(Rso[iR][ialign][iori]).mean())) #print('sum of Rso isnan: '+str(np.isnan(Rso[iR][ialign][iori]).sum(1))) #Rso[iR][ialign][iori] = Rso[iR][ialign][iori]/np.nanmean(Rso[iR][ialign][iori],-1)[:,np.newaxis] def set_bound(bd, code, val=0): # set bounds to 0 where 0s occur in 'code' for iitem in range(len(bd)): bd[iitem][code[iitem]] = val nN = 36 nS = 2 nP = 2 nT = 1 nQ = 4 # code for bounds: 0 , constrained to 0 # +/-1 , constrained to +/-1 # 1.5, constrained to [0,1] # 2 , constrained to 
[0,inf) # -2 , constrained to (-inf,0] # 3 , unconstrained Wmx_bounds = 3 * np.ones((nP, nQ), dtype=int) Wmx_bounds[0, :] = 2 # L4 PCs are excitatory Wmx_bounds[0, 1] = 0 # SSTs don't receive L4 input if allow_var: Wsx_bounds = 3 * np.ones( Wmx_bounds.shape) #Wmx_bounds.copy()*0 #np.zeros_like(Wmx_bounds) Wsx_bounds[0, 1] = 0 else: Wsx_bounds = np.zeros( Wmx_bounds.shape) #Wmx_bounds.copy()*0 #np.zeros_like(Wmx_bounds) Wmy_bounds = 3 * np.ones((nQ, nQ), dtype=int) Wmy_bounds[0, :] = 2 # PCs are excitatory Wmy_bounds[1:, :] = -2 # all the cell types except PCs are inhibitory Wmy_bounds[1, 1] = 0 # SSTs don't inhibit themselves # Wmy_bounds[3,1] = 0 # PVs are allowed to inhibit SSTs, consistent with Hillel's unpublished results, but not consistent with Pfeffer et al. Wmy_bounds[ 2, 0] = 0 # VIPs don't inhibit L2/3 PCs. According to Pfeffer et al., only L5 PCs were found to get VIP inhibition if not zero_extra_weights is None: Wmx_bounds[zero_extra_weights[0]] = 0 Wmy_bounds[zero_extra_weights[1]] = 0 if allow_var: Wsy_bounds = 3 * np.ones( Wmy_bounds.shape) #Wmy_bounds.copy()*0 #np.zeros_like(Wmy_bounds) Wsy_bounds[1, 1] = 0 Wsy_bounds[3, 1] = 0 Wsy_bounds[2, 0] = 0 else: Wsy_bounds = np.zeros( Wmy_bounds.shape) #Wmy_bounds.copy()*0 #np.zeros_like(Wmy_bounds) if not constrain_wts is None: for wt in constrain_wts: Wmy_bounds[wt[0], wt[1]] = 0 Wsy_bounds[wt[0], wt[1]] = 0 def tile_nS_nT_nN(kernel): row = np.concatenate([kernel for idim in range(nS * nT)], axis=0)[np.newaxis, :] tiled = np.concatenate([row for irow in range(nN)], axis=0) return tiled def set_bounds_by_code(lb, ub, bdlist): set_bound(lb, [bd == 0 for bd in bdlist], val=0) set_bound(ub, [bd == 0 for bd in bdlist], val=0) set_bound(lb, [bd == 2 for bd in bdlist], val=0) set_bound(ub, [bd == -2 for bd in bdlist], val=0) set_bound(lb, [bd == 1 for bd in bdlist], val=1) set_bound(ub, [bd == 1 for bd in bdlist], val=1) set_bound(lb, [bd == 1.5 for bd in bdlist], val=0) set_bound(ub, [bd == 1.5 for bd in bdlist], val=1) set_bound(lb, [bd == -1 for bd in bdlist], val=-1) set_bound(ub, [bd == -1 for bd in bdlist], val=-1) if fit_s02: s02_bounds = 2 * np.ones( (nQ, )) # permitting noise as a free parameter else: s02_bounds = np.ones((nQ, )) k_bounds = 1.5 * np.ones((nQ * (nS - 1), )) #k_bounds[1] = 0 # temporary: spatial kernel constrained to 0 for SST #k_bounds[2] = 0 # temporary: spatial kernel constrained to 0 for VIP kappa_bounds = np.ones((1, )) # kappa_bounds = 2*np.ones((1,)) T_bounds = 1.5 * np.ones((nQ * (nT - 1), )) X_bounds = tile_nS_nT_nN(np.array([2, 1])) # X_bounds = np.array([np.array([2,1,2,1])]*nN) Xp_bounds = tile_nS_nT_nN(np.array([3, 1])) # Xp_bounds = np.array([np.array([3,1,3,1])]*nN) # Y_bounds = tile_nS_nT_nN(2*np.ones((nQ,))) # # Y_bounds = 2*np.ones((nN,nT*nS*nQ)) Eta_bounds = tile_nS_nT_nN(3 * np.ones((nQ, ))) # Eta_bounds = 3*np.ones((nN,nT*nS*nQ)) if allow_var: Xi_bounds = tile_nS_nT_nN(3 * np.ones((nQ, ))) else: Xi_bounds = tile_nS_nT_nN(np.zeros((nQ, ))) # Xi_bounds = 3*np.ones((nN,nT*nS*nQ)) h1_bounds = -2 * np.ones((1, )) h2_bounds = 2 * np.ones((1, )) bl_bounds = 3 * np.ones((nQ, )) if free_amplitude: amp_bounds = 2 * np.ones((nT * nS * nQ, )) else: amp_bounds = 1 * np.ones((nT * nS * nQ, )) # shapes = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ,),(1,),(nN,nS*nP),(nN,nS*nQ),(nN,nS*nQ),(nN,nS*nQ)] shapes1 = [(nP, nQ), (nQ, nQ), (nP, nQ), (nQ, nQ), (nQ, ), (nQ * (nS - 1), ), (1, ), (nQ * (nT - 1), ), (1, ), (1, ), (nQ, ), (nQ * nS * nT, )] shapes2 = [(nN, nT * nS * nP), (nN, nT * nS * nP), (nN, nT * nS * 
    #print('size of shapes1: '+str(np.sum([np.prod(shp) for shp in shapes1])))
    #print('size of shapes2: '+str(np.sum([np.prod(shp) for shp in shapes2])))

    # W1: Wmx, Wmy, Wsx, Wsy, s02, k, kappa, T, h1, h2
    # W2: XX, XXp, Eta, Xi
    #bdlist = [Wmx_bounds,Wmy_bounds,Wsx_bounds,Wsy_bounds,s02_bounds,k_bounds,kappa_bounds,T_bounds,X_bounds,Xp_bounds,Eta_bounds,Xi_bounds,h1_bounds,h2_bounds]
    bd1list = [
        Wmx_bounds, Wmy_bounds, Wsx_bounds, Wsy_bounds, s02_bounds, k_bounds,
        kappa_bounds, T_bounds, h1_bounds, h2_bounds, bl_bounds, amp_bounds
    ]
    bd2list = [X_bounds, Xp_bounds, Eta_bounds, Xi_bounds, X_bounds, X_bounds]

    lb1, ub1 = [[sgn * np.inf * np.ones(shp) for shp in shapes1]
                for sgn in [-1, 1]]
    set_bounds_by_code(lb1, ub1, bd1list)
    lb2, ub2 = [[sgn * np.inf * np.ones(shp) for shp in shapes2]
                for sgn in [-1, 1]]
    set_bounds_by_code(lb2, ub2, bd2list)

    #set_bound(lb,[bd==0 for bd in bdlist],val=0)
    #set_bound(ub,[bd==0 for bd in bdlist],val=0)
    #set_bound(lb,[bd==2 for bd in bdlist],val=0)
    #set_bound(ub,[bd==-2 for bd in bdlist],val=0)
    #set_bound(lb,[bd==1 for bd in bdlist],val=1)
    #set_bound(ub,[bd==1 for bd in bdlist],val=1)
    #set_bound(lb,[bd==1.5 for bd in bdlist],val=0)
    #set_bound(ub,[bd==1.5 for bd in bdlist],val=1)
    #set_bound(lb,[bd==-1 for bd in bdlist],val=-1)
    #set_bound(ub,[bd==-1 for bd in bdlist],val=-1)
    # for bd in [lb,ub]:
    #     for ind in [2,3]:
    #         bd[ind][:,1] = 0
    # temporary for no variation expt:
    # lb[2] = np.zeros_like(lb[2])
    # lb[3] = np.zeros_like(lb[3])
    # lb[4] = np.ones_like(lb[4])
    # lb[5] = np.zeros_like(lb[5])
    # ub[2] = np.zeros_like(ub[2])
    # ub[3] = np.zeros_like(ub[3])
    # ub[4] = np.ones_like(ub[4])
    # ub[5] = np.ones_like(ub[5])

    lb1 = np.concatenate([a.flatten() for a in lb1])
    ub1 = np.concatenate([b.flatten() for b in ub1])
    lb2 = np.concatenate([a.flatten() for a in lb2])
    ub2 = np.concatenate([b.flatten() for b in ub2])
    bounds1 = [(a, b) for a, b in zip(lb1, ub1)]
    bounds2 = [(a, b) for a, b in zip(lb2, ub2)]

    nS = 2
    #print('nT: '+str(nT))
    ndims = 5
    ncelltypes = 5
    Yhat = [[None for iT in range(nT)] for iS in range(nS)]
    Xhat = [[None for iT in range(nT)] for iS in range(nS)]
    Ypc_list = [[None for iT in range(nT)] for iS in range(nS)]
    Xpc_list = [[None for iT in range(nT)] for iS in range(nS)]
    mx = [None for iS in range(nS)]
    for iS in range(nS):
        mx[iS] = np.zeros((ncelltypes,))
        yy = [None for icelltype in range(ncelltypes)]
        for icelltype in range(ncelltypes):
            yy[icelltype] = np.nanmean(Rso[icelltype][iS][0], 0)
            mx[iS][icelltype] = np.nanmax(yy[icelltype])
        for iT in range(nT):
            y = [
                np.nanmean(Rso[icelltype][iS][iT], axis=0)[:, np.newaxis] /
                mx[iS][icelltype] for icelltype in range(1, ncelltypes)
            ]
            Ypc_list[iS][iT] = [None for icelltype in range(1, ncelltypes)]
            for icelltype in range(1, ncelltypes):
                rss = Rso[icelltype][iS][iT].copy()  #/mx[iS][icelltype] #.reshape(Rs[icelltype][ialign].shape[0],-1)
                #print('sum of isnan: '+str(np.isnan(rss).sum(1)))
                #rss = Rso[icelltype][iS][iT].copy() #.reshape(Rs[icelltype][ialign].shape[0],-1)
                rss = rss[np.isnan(rss).sum(1) == 0]
                # print(rss.max())
                # rss[rss<0] = 0
                # rss = rss[np.random.randn(rss.shape[0])>0]
                try:
                    u, s, v = np.linalg.svd(rss - np.mean(rss, 0)[np.newaxis])
                    Ypc_list[iS][iT][icelltype - 1] = [
                        (s[idim], v[idim]) for idim in range(ndims)
                    ]
                    # print('yep on Y')
                    # print(np.min(np.sum(rs[icelltype][iS][iT],axis=1)))
                except:
                    print('nope on Y')
                    #print('shape of rss: '+str(rss.shape))
                    #print('mean of rss: '+str(np.mean(np.isnan(rss))))
                    #print('min of this rs: '+str(np.min(np.sum(rs[icelltype][iS][iT],axis=1))))
            Yhat[iS][iT] = np.concatenate(y, axis=1)
            # x = sim_utils.columnize(Rso[0][iS][iT])[:,np.newaxis]
            icelltype = 0
            #x = np.nanmean(Rso[icelltype][iS][iT],0)[:,np.newaxis]#/mx[iS][icelltype]
            x = np.nanmean(Rso[icelltype][iS][iT],
                           0)[:, np.newaxis] / mx[iS][icelltype]
            # opto_column = np.concatenate((np.zeros((nN,)),np.zeros((nNO/2,)),np.ones((nNO/2,))),axis=0)[:,np.newaxis]
            Xhat[iS][iT] = np.concatenate((x, np.ones_like(x)), axis=1)
            # Xhat[iS][iT] = np.concatenate((x,np.ones_like(x),opto_column),axis=1)
            icelltype = 0
            #rss = Rso[icelltype][iS][iT].copy()/mx[iS][icelltype]
            rss = Rso[icelltype][iS][iT].copy()
            rss = rss[np.isnan(rss).sum(1) == 0]
            # try:
            u, s, v = np.linalg.svd(rss - rss.mean(0)[np.newaxis])
            Xpc_list[iS][iT] = [None for iinput in range(2)]
            Xpc_list[iS][iT][0] = [(s[idim], v[idim]) for idim in range(ndims)]
            Xpc_list[iS][iT][1] = [(0, np.zeros((Xhat[0][0].shape[0],)))
                                   for idim in range(ndims)]
            # except:
            #     print('nope on X')
            #     print(np.mean(np.isnan(rss)))
            #     print(np.min(np.sum(Rso[icelltype][iS][iT],axis=1)))

    nN, nP = Xhat[0][0].shape
    #print('nP: '+str(nP))
    nQ = Yhat[0][0].shape[1]

    def compute_f_(Eta, Xi, s02):
        return sim_utils.f_miller_troyer(
            Eta, Xi**2 + np.concatenate([s02 for ipixel in range(nS * nT)]))

    def compute_fprime_m_(Eta, Xi, s02):
        return sim_utils.fprime_miller_troyer(
            Eta,
            Xi**2 + np.concatenate([s02 for ipixel in range(nS * nT)])) * Xi

    def compute_fprime_s_(Eta, Xi, s02):
        s2 = Xi**2 + np.concatenate((s02, s02), axis=0)
        return sim_utils.fprime_s_miller_troyer(Eta, s2) * (Xi / s2)

    def sorted_r_eigs(w):
        drW, prW = np.linalg.eig(w)
        srtinds = np.argsort(drW)
        return drW[srtinds], prW[:, srtinds]

    # 0.Wmx, 1.Wmy, 2.Wsx, 3.Wsy, 4.s02, 5.K, 6.kappa, 7.T, 8.XX, 9.XXp,
    # 10.Eta, 11.Xi, 12.h1, 13.h2
    shapes1 = [(nP, nQ), (nQ, nQ), (nP, nQ), (nQ, nQ), (nQ,),
               (nQ * (nS - 1),), (1,), (nQ * (nT - 1),), (1,), (1,), (nQ,),
               (nT * nS * nQ,)]
    shapes2 = [(nN, nT * nS * nP), (nN, nT * nS * nP), (nN, nT * nS * nQ),
               (nN, nT * nS * nQ), (nN, nT * nS * nP), (nN, nT * nS * nP)]
    #print('size of shapes1: '+str(np.sum([np.prod(shp) for shp in shapes1])))
    #print('size of shapes2: '+str(np.sum([np.prod(shp) for shp in shapes2])))

    import calnet.fitting_spatial_feature
    import sim_utils

    YYhat = calnet.utils.flatten_nested_list_of_2d_arrays(Yhat)
    XXhat = calnet.utils.flatten_nested_list_of_2d_arrays(Xhat)

    opto_dict = np.load(opto_silencing_data_file, allow_pickle=True)[()]

    Yhat_opto = opto_dict['Yhat_opto']
    Yhat_opto = np.nanmean(np.reshape(Yhat_opto, (nN, 2, nS, 2, nQ)),
                           3).reshape((nN * 2, -1))
    Yhat_opto[0::12] = np.nanmean(Yhat_opto[0::12], axis=0)[np.newaxis]
    Yhat_opto[1::12] = np.nanmean(Yhat_opto[1::12], axis=0)[np.newaxis]
    Yhat_opto = Yhat_opto / np.nanmax(Yhat_opto[0::2], 0)[np.newaxis, :]
    #print(Yhat_opto.shape)
    h_opto = opto_dict['h_opto']
    #dYY1 = Yhat_opto[1::2]-Yhat_opto[0::2]

    Xhat_opto = opto_dict['Xhat_opto']
    Xhat_opto = np.nanmean(np.reshape(Xhat_opto, (nN, 2, nS, 2, nP)),
                           3).reshape((nN * 2, -1))
    Xhat_opto[0::12] = np.nanmean(Xhat_opto[0::12], axis=0)[np.newaxis]
    Xhat_opto[1::12] = np.nanmean(Xhat_opto[1::12], axis=0)[np.newaxis]
    Xhat_opto = Xhat_opto / np.nanmax(Xhat_opto[0::2], 0)[np.newaxis, :]

    YYhat_halo = Yhat_opto.reshape((nN, 2, -1))
    opto_transform1 = calnet.utils.fit_opto_transform(
        YYhat_halo, norm01=norm_opto_transforms)

    if l23_as_l4:
        Xhat_opto[:, 0::2] = Yhat_opto[:, 0::4]

    Xhat_halo = Xhat_opto.reshape((nN, 2, -1))
    opto_transform1x = calnet.utils.fit_opto_transform(
        Xhat_halo, norm01=norm_opto_transforms)
    if no_halo_res:
        opto_transform1.res[:, [0, 2, 3, 4, 6, 7]] = 0
        opto_transform1x.res[:, :] = 0

    dYY1 = opto_transform1.transform(YYhat) - opto_transform1.preprocess(YYhat)
    dXX1 = opto_transform1x.transform(XXhat) - opto_transform1x.preprocess(XXhat)

    print('delta bias: %f' % dXX1[:, 1].mean())

    #YYhat_halo_sim = calnet.utils.simulate_opto_effect(YYhat,YYhat_halo)
    #dYY1 = YYhat_halo_sim[:,1,:] - YYhat_halo_sim[:,0,:]

    def overwrite_plus_n(arr, to_overwrite, n):
        arr[:, to_overwrite] = arr[:, int(to_overwrite + n)]
        return arr

    for to_overwrite in [1, 2]:
        n = 4
        dYY1, opto_transform1.slope, opto_transform1.intercept, opto_transform1.res \
            = [overwrite_plus_n(x, to_overwrite, n) for x in
               [dYY1, opto_transform1.slope, opto_transform1.intercept, opto_transform1.res]]
    for to_overwrite in [7]:
        n = -4
        dYY1, opto_transform1.slope, opto_transform1.intercept, opto_transform1.res \
            = [overwrite_plus_n(x, to_overwrite, n) for x in
               [dYY1, opto_transform1.slope, opto_transform1.intercept, opto_transform1.res]]
    for to_overwrite in [2]:
        n = -2
        dXX1, opto_transform1x.slope, opto_transform1x.intercept, opto_transform1x.res \
            = [overwrite_plus_n(x, to_overwrite, n) for x in
               [dXX1, opto_transform1x.slope, opto_transform1x.intercept, opto_transform1x.res]]

    if ignore_halo_vip:
        dYY1[:, 2::nQ] = np.nan

    #for to_overwrite in [1,2]:
    #    dYY1[:,to_overwrite] = dYY1[:,to_overwrite+4]
    #for to_overwrite in [7]:
    #    dYY1[:,to_overwrite] = dYY1[:,to_overwrite-4]
    #Yhat_opto = opto_dict['Yhat_opto']
    #for iS in range(nS):
    #    mx = np.zeros((nQ,))
    #    for iQ in range(nQ):
    #        slicer = slice(nQ*nT*iS+iQ,nQ*nT*(1+iS),nQ)
    #        mx[iQ] = np.nanmax(Yhat_opto[0::2][:,slicer])
    #        Yhat_opto[:,slicer] = Yhat_opto[:,slicer]/mx[iQ]
    ##Yhat_opto = Yhat_opto/Yhat_opto[0::2].max(0)[np.newaxis,:]
    #print(Yhat_opto.shape)
    #h_opto = opto_dict['h_opto']
    #dYY1 = Yhat_opto[1::2]-Yhat_opto[0::2]
    #for to_overwrite in [1,2,5,6]: # overwrite sst and vip with off-centered values
    #    dYY1[:,to_overwrite] = dYY1[:,to_overwrite+8]
    #for to_overwrite in [11,15]:
    #    dYY1[:,to_overwrite] = np.nan #dYY1[:,to_overwrite-8]

    opto_dict = np.load(opto_activation_data_file, allow_pickle=True)[()]

    Yhat_opto = opto_dict['Yhat_opto']
    Yhat_opto = np.nanmean(np.reshape(Yhat_opto, (nN, 2, nS, 2, nQ)),
                           3).reshape((nN * 2, -1))
    Yhat_opto[0::12] = np.nanmean(Yhat_opto[0::12], axis=0)[np.newaxis]
    Yhat_opto[1::12] = np.nanmean(Yhat_opto[1::12], axis=0)[np.newaxis]
    Yhat_opto = Yhat_opto / Yhat_opto[0::2].max(0)[np.newaxis, :]
    #print(Yhat_opto.shape)
    h_opto = opto_dict['h_opto']
    #dYY2 = Yhat_opto[1::2]-Yhat_opto[0::2]

    YYhat_chrimson = Yhat_opto.reshape((nN, 2, -1))
    opto_transform2 = calnet.utils.fit_opto_transform(
        YYhat_chrimson, norm01=norm_opto_transforms)

    Xhat_opto = np.nan * np.ones((Yhat_opto.shape[0], nP * nS * nT))
    Xhat_opto[:, 1::2] = 1
    if l23_as_l4:
        Xhat_opto[:, 0::2] = Yhat_opto[:, 0::4]
    Xhat_chrimson = Xhat_opto.reshape((nN, 2, -1))
    opto_transform2x = calnet.utils.fit_opto_transform(
        Xhat_chrimson, norm01=norm_opto_transforms)

    dYY2 = opto_transform2.transform(YYhat) - opto_transform2.preprocess(YYhat)
    dXX2 = opto_transform2x.transform(XXhat) - opto_transform2x.preprocess(XXhat)

    #YYhat_chrimson_sim = calnet.utils.simulate_opto_effect(YYhat,YYhat_chrimson)
    #dYY2 = YYhat_chrimson_sim[:,1,:] - YYhat_chrimson_sim[:,0,:]
    #Yhat_opto = opto_dict['Yhat_opto']
    #for iS in range(nS):
    #    mx = np.zeros((nQ,))
    #    for iQ in range(nQ):
    #        slicer = slice(nQ*nT*iS+iQ,nQ*nT*(1+iS),nQ)
    #        mx[iQ] = np.nanmax(Yhat_opto[0::2][:,slicer])
    #        Yhat_opto[:,slicer] = Yhat_opto[:,slicer]/mx[iQ]
    ##Yhat_opto = Yhat_opto/Yhat_opto[0::2].max(0)[np.newaxis,:]
    #print(Yhat_opto.shape)
    #h_opto = opto_dict['h_opto']
    #dYY2 = Yhat_opto[1::2]-Yhat_opto[0::2]

    #print('dYY1 mean: %03f'%np.nanmean(np.abs(dYY1)))
    #print('dYY2 mean: %03f'%np.nanmean(np.abs(dYY2)))

    dYY = np.concatenate((dYY1, dYY2), axis=0)
    dXX = np.concatenate((dXX1, dXX2), axis=0)

    #titles = ['VIP silencing','VIP activation']
    #for itype in [0,1,2,3]:
    #    plt.figure(figsize=(5,2.5))
    #    for iyy,dyy in enumerate([dYY1,dYY2]):
    #        plt.subplot(1,2,iyy+1)
    #        if np.sum(np.isnan(dyy[:,itype]))==0:
    #            sca.scatter_size_contrast(YYhat[:,itype],YYhat[:,itype]+dyy[:,itype],nsize=6,ncontrast=6)#,mn=0)
    #        plt.title(titles[iyy])
    #        plt.xlabel('cell type %d event rate, \n light off'%itype)
    #        plt.ylabel('cell type %d event rate, \n light on'%itype)
    #        ut.erase_top_right()
    #    plt.tight_layout()
    #    ut.mkdir('figures')
    #    plt.savefig('figures/scatter_light_on_light_off_target_celltype_%d.eps'%itype)

    opto_mask = ~np.isnan(dYY)
    opto_maskX = ~np.isnan(dXX)

    #dYY[nN:][~opto_mask[nN:]] = -dYY[:nN][~opto_mask[nN:]]
    #print('mean of opto_mask: '+str(opto_mask.mean()))
    #dYY[~opto_mask] = 0

    def zero_nans(arr):
        arr[np.isnan(arr)] = 0
        return arr

    #dYY,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res,\
    #    opto_transform2.slope,opto_transform2.intercept,opto_transform2.res\
    #    = [zero_nans(x) for x in \
    #       [dYY,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res,\
    #        opto_transform2.slope,opto_transform2.intercept,opto_transform2.res]]
    dYY = zero_nans(dYY)
    dXX = zero_nans(dXX)

    # For cell types that were not measured with chrimson, fill with values
    # inferred from halo data (this shouldn't matter, as these entries are
    # masked by opto_mask).
    to_adjust = np.logical_or(np.isnan(opto_transform2.slope[0]),
                              np.isnan(opto_transform2.intercept[0]))
    opto_transform2.slope[:, to_adjust] = \
        1 / opto_transform1.slope[:, to_adjust]
    opto_transform2.intercept[:, to_adjust] = \
        -opto_transform1.intercept[:, to_adjust] / opto_transform1.slope[:, to_adjust]
    opto_transform2.res[:, to_adjust] = \
        -opto_transform1.res[:, to_adjust] / opto_transform1.slope[:, to_adjust]

    #np.save('/Users/dan/Documents/notebooks/mossing-PC/shared_data/calnet_data/dYY.npy',dYY)

    from importlib import reload
    reload(calnet)
    #reload(calnet.fitting_2step_spatial_feature_opto_tight_nonlinear)
    reload(sim_utils)
    # reload(calnet.fitting_spatial_feature)

    # W0list = [np.ones(shp) for shp in shapes]
    wt_dict = {}
    wt_dict['X'] = 1
    wt_dict['Y'] = 3
    wt_dict['Eta'] = 3  # 1
    # wt_dict['Xi'] = 0.1
    wt_dict['stims'] = np.ones((nN, 1))  #(np.arange(30)/30)[:,np.newaxis]**1
    # wt_dict['barrier'] = 0. #30.0 #0.1
    wt_dict['opto'] = 1  #1e1
    wt_dict['isn'] = 0.3
    wt_dict['tv'] = 1
    spont_frac = 0.5
    pc_frac = 0.5
    wt_dict['stimsOpto'] = (1 - spont_frac) * 6 / 5 * np.ones((nN, 1))
    wt_dict['stimsOpto'][0::6] = spont_frac * 6
    wt_dict['celltypesOpto'] = (1 - pc_frac) * 4 / 3 * np.ones((1, nQ * nS * nT))
    wt_dict['celltypesOpto'][0, 0::nQ] = pc_frac * 4
    wt_dict['dirOpto'] = np.array((1, 0.3))
    wt_dict['dYY'] = 10  #10
    wt_dict['dXX'] = 10  #10
    wt_dict['coupling'] = 1e-3
    wt_dict['smi'] = 0.1
    wt_dict['smi_halo'] = 30
    wt_dict['smi_chrimson'] = 0.1

    ## temporary no_opto
    wt_dict['opto'] = 0.01  #0
    wt_dict['dirOpto'] = np.array((1, 1))
    wt_dict['stimsOpto'] = np.ones((nN, 1))
    wt_dict['celltypesOpto'] = np.ones((1, nQ * nS * nT))
    wt_dict['smi'] = 0  #0.01 # 0
    wt_dict['smi_halo'] = 0  #1 # 0
    wt_dict['smi_chrimson'] = 0  #0.01 # 0
    wt_dict['isn'] = 0.1
    wt_dict['tv'] = 0.1
    wt_dict['X'] = 3
    #wt_dict['Eta'] = 300 # 1

    ## temporary opto from no_opto
    #wt_dict['opto'] = 0.01
    #wt_dict['tv'] = 0.3 #0.1

    np.save(
        'XXYYhat.npy', {
            'YYhat': YYhat,
            'XXhat': XXhat,
            'rs': rs,
            'Rs': Rs,
            'Rso': Rso,
            'Ypc_list': Ypc_list,
            'Xpc_list': Xpc_list
        })
    Eta0 = invert_f_mt(YYhat)

    # W1: Wmx, Wmy, Wsx, Wsy, s02, k, kappa, T, h1, h2
    # W2: XX, XXp, Eta, Xi, XX1, XX2
    ntries = 1
    nhyper = 1
    dt = 1e-1
    niter = int(np.round(10 / dt))  #int(1e4)
    perturbation_size = 5e-2
    # learning_rate = 1e-4 # 1e-5 #np.linspace(3e-4,1e-3,niter+1) # 1e-5
    #l2_penalty = 0.1
    W1t = [[None for itry in range(ntries)] for ihyper in range(nhyper)]
    W2t = [[None for itry in range(ntries)] for ihyper in range(nhyper)]
    loss = np.zeros((nhyper, ntries))
    is_neg = np.array([b[1] for b in bounds1]) == 0
    counter = 0
    negatize = [np.zeros(shp, dtype='bool') for shp in shapes1]
    #print(shapes1)
    for ishp, shp in enumerate(shapes1):
        nel = np.prod(shp)
        negatize[ishp][:][is_neg[counter:counter + nel].reshape(shp)] = True
        counter = counter + nel
    for ihyper in range(nhyper):
        for itry in range(ntries):
            #print((ihyper,itry))
            W10list = [
                init_noise * (ihyper + 1) * np.random.rand(*shp)
                for shp in shapes1
            ]
            W20list = [
                init_noise * (ihyper + 1) * np.random.rand(*shp)
                for shp in shapes2
            ]
            #print('size of shapes1: '+str(np.sum([np.prod(shp) for shp in shapes1])))
            #print('size of w10: '+str(np.sum([np.size(x) for x in W10list])))
            #print('len(W10list): '+str(len(W10list)))
            counter = 0
            for ishp, shp in enumerate(shapes1):
                W10list[ishp][negatize[ishp]] = -W10list[ishp][negatize[ishp]]
            W10list[4] = np.ones(shapes1[4])  # s02
            W10list[5] = np.ones(shapes1[5])  # K
            W10list[6] = np.ones(shapes1[6])  # kappa
            W10list[7] = np.ones(shapes1[7])  # T
            W20list[0] = np.concatenate(Xhat, axis=1)  # XX
            W20list[1] = np.zeros_like(W20list[1])  # XXp
            W20list[2] = Eta0.copy()  #np.zeros(shapes[10]) # Eta
            W20list[3] = np.zeros(shapes2[3])  # Xi
            W20list[4] = np.concatenate(Xhat, axis=1)  # XX1
            W20list[5] = np.concatenate(Xhat, axis=1)  # XX2
            #[Wmx,Wmy,Wsx,Wsy,s02,k,kappa,T,XX,XXp,Eta,Xi]
            if init_W_from_lsq:
                W10list[0], W10list[1] = initialize_W(Xhat, Yhat,
                                                      scale_by=scale_init_by)
                for ivar in range(0, 2):
                    W10list[ivar] = W10list[ivar] + \
                        init_noise * np.random.randn(*W10list[ivar].shape)
            if constrain_isn:
                W10list[1][0, 0] = 3
                if help_constrain_isn:
                    W10list[1][0, 3] = 5
                    W10list[1][3, 0] = -5
                    W10list[1][3, 3] = -5
                else:
                    W10list[1][0, 1:4] = 5
                    W10list[1][1:4, 0] = -5
            if init_W_from_file:
                npyfile = np.load(init_file, allow_pickle=True)[()]
                #Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,h1,h2,bl,amp = parse_W1(W1)
                #XX,XXp,Eta,Xi,XX1,XX2 = parse_W2(W2)
                #Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,XX1,XX2,h1,h2,bl,amp = parse_W1(W1)
                if len(npyfile['as_list']) == 18:
                    W10list = [
                        npyfile['as_list'][ivar]
                        for ivar in [0, 1, 2, 3, 4, 5, 6, 7, 14, 15, 16, 17]
                    ]
                    W20list = [
                        npyfile['as_list'][ivar]
                        for ivar in [8, 9, 10, 11, 12, 13]
                    ]
                elif len(npyfile['as_list']) == 16:
                    W10list = [
                        npyfile['as_list'][ivar]
                        for ivar in [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15]
                    ]
                    W20list = [
                        npyfile['as_list'][ivar]
                        for ivar in [8, 9, 10, 11, 8, 8]
                    ]
                if W20list[0].size == nN * nS * 2 * nP:
                    #assert(True==False)
                    W10list[7] = np.array(())
                    W10list[1][1, 0] = W10list[1][1, 0]  # no-op, kept from original
                    W20list[0] = np.nanmean(
                        W20list[0].reshape((nN, nS, 2, nP)), 2).flatten()  # XX
                    W20list[1] = np.nanmean(
                        W20list[1].reshape((nN, nS, 2, nP)), 2).flatten()  # XXp
                    W20list[2] = np.nanmean(
                        W20list[2].reshape((nN, nS, 2, nQ)), 2).flatten()  # Eta
                    W20list[3] = np.nanmean(
                        W20list[3].reshape((nN, nS, 2, nQ)), 2).flatten()  # Xi
                    W20list[4] = np.nanmean(
                        W20list[4].reshape((nN, nS, 2, nP)), 2).flatten()  # XX1
                    W20list[5] = np.nanmean(
                        W20list[5].reshape((nN, nS, 2, nP)), 2).flatten()  # XX2
                if correct_Eta:
                    #assert(True==False)
                    W20list[2] = Eta0.copy()
                if len(W10list) < len(shapes1):
                    #assert(True==False)
                    W10list = W10list + [
                        np.array(1),
                        np.zeros((nQ,)),
                        np.zeros((nT * nS * nQ,))
                    ]  # add h2, bl, amp
                if init_Eta_with_s02:
                    #assert(True==False)
                    s02 = W10list[4].copy()
                    Eta0 = invert_f_mt_with_s02(YYhat, s02, nS=nS, nT=nT)
                    W20list[2] = Eta0.copy()
                #if init_Eta12_with_dYY:
                #    Eta0 = W20list[2].copy().reshape((nN,nQ*nS*nT))
                #    Xi0 = W20list[3].copy().reshape((nN,nQ*nS*nT))
                #    s020 = W10list[4].copy()
                #    YY0s = compute_f_(Eta0,Xi0,s020)
                #titles = ['VIP silencing','VIP activation']
                #for itype in [0,1,2,3]:
                #    plt.figure(figsize=(5,2.5))
                #    for iyy,yy in enumerate([YY10s,YY20s]):
                #        plt.subplot(1,2,iyy+1)
                #        if np.sum(np.isnan(yy[:,itype]))==0:
                #            sca.scatter_size_contrast(YY0s[:,itype],yy[:,itype],nsize=6,ncontrast=6)#,mn=0)
                #        plt.title(titles[iyy])
                #        plt.xlabel('cell type %d event rate, \n light off'%itype)
                #        plt.ylabel('cell type %d event rate, \n light on'%itype)
                #        ut.erase_top_right()
                #    plt.tight_layout()
                #    ut.mkdir('figures')
                #    plt.savefig('figures/scatter_light_on_light_off_init_celltype_%d.eps'%itype)
                for ivar in [0, 1, 4, 5]:  # Wmx, Wmy, s02, k
                    print(init_noise)
                    W10list[ivar] = W10list[ivar] + \
                        init_noise * np.random.randn(*W10list[ivar].shape)
            #print('size of bounds1: '+str(np.sum([np.size(x) for x in bd1list])))
            #print('size of w10: '+str(np.sum([np.size(x) for x in W10list])))
            #print('size of shapes1: '+str(np.sum([np.prod(shp) for shp in shapes1])))
            W1t[ihyper][itry], W2t[ihyper][itry], loss[ihyper][itry], gr, hess, result = \
                calnet.fitting_2step_opto_layers.fit_W_sim(
                    Xhat,
                    Xpc_list,
                    Yhat,
                    Ypc_list,
                    pop_rate_fn=sim_utils.f_miller_troyer,
                    pop_deriv_fn=sim_utils.fprime_miller_troyer,
                    neuron_rate_fn=sim_utils.evaluate_f_mt,
                    W10list=W10list.copy(),
                    W20list=W20list.copy(),
                    bounds1=bounds1,
                    bounds2=bounds2,
                    niter=niter,
                    wt_dict=wt_dict,
                    l2_penalty=l2_penalty,
                    compute_hessian=False,
                    dt=dt,
                    perturbation_size=perturbation_size,
                    dYY=dYY,
                    dXX=dXX,
                    constrain_isn=constrain_isn,
                    tv=tv,
                    opto_mask=opto_mask,
                    opto_maskX=opto_maskX,
                    use_opto_transforms=use_opto_transforms,
                    opto_transform1=opto_transform1,
                    opto_transform1x=opto_transform1x,
                    opto_transform2=opto_transform2,
                    opto_transform2x=opto_transform2x,
                    share_residuals=share_residuals,
                    stimwise=stimwise,
                    simulate1=simulate1,
                    simulate2=simulate2,
                    verbose=verbose)

    #def parse_W(W):
    #    Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,h1,h2 = W
    #    return Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,h1,h2

    def parse_W1(W):
        Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, h1, h2, bl, amp = W
        return Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, h1, h2, bl, amp
    def parse_W2(W):
        XX, XXp, Eta, Xi, XX1, XX2 = W
        return XX, XXp, Eta, Xi, XX1, XX2

    itry = 0
    Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, h1, h2, bl, amp = parse_W1(W1t[0][0])
    XX, XXp, Eta, Xi, XX1, XX2 = parse_W2(W2t[0][0])

    labels1 = [
        'Wmx', 'Wmy', 'Wsx', 'Wsy', 's02', 'K', 'kappa', 'T', 'h1', 'h2',
        'bl', 'amp'
    ]
    labels2 = ['XX', 'XXp', 'Eta', 'Xi', 'XX1', 'XX2']
    Wstar_dict = {}
    for i, label in enumerate(labels1):
        Wstar_dict[label] = W1t[0][0][i]
    for i, label in enumerate(labels2):
        Wstar_dict[label] = W2t[0][0][i]
    Wstar_dict['as_list'] = [
        Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, XX, XXp, Eta, Xi, XX1, XX2, h1,
        h2, bl, amp
    ]
    Wstar_dict['loss'] = loss[0][0]
    Wstar_dict['wt_dict'] = wt_dict
    np.save(weights_file, Wstar_dict, allow_pickle=True)
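# A minimal sketch of reading the fit results back, mirroring the
# np.save(..., allow_pickle=True) convention used above. The key names
# ('Wmx', 'as_list', 'loss', ...) come from Wstar_dict; the file path is
# hypothetical and stands in for whatever was passed as weights_file.
import numpy as np

wstar = np.load('weights.npy', allow_pickle=True)[()]  # recover the dict
Wmx, Wmy = wstar['Wmx'], wstar['Wmy']
print('final loss:', wstar['loss'])
print('number of saved arrays:', len(wstar['as_list']))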
##mfvi_params = mfvi_init()
#lrd_params = init_single_component(rank=0)
#lrd_params = init_single_component(rank=3)

###########################
#  Variational Boosting   #
###########################
if args.vboost:
    # initialize a single component of appropriate rank and cache
    vi_params, infobj = init_single_component(rank=args.rank)
    init_file = os.path.join(args.output,
                             "initial_component-rank_%d.npy" % args.rank)
    np.save(init_file, vi_params)

    # initialize LRD component (works with MixtureVI ...)
    comp = LRDComponent(D, rank=args.rank)
    m, d, C = vi_params[:D], vi_params[D:(2 * D)], vi_params[2 * D:]
    comp.lam = comp.setter(vi_params.copy(), mean=m, v=d, C=C)

    # Variational Boosting object (with the initial component)
    from vbproj.vi.vboost import mog_bbvi
    vbobj = vi.MixtureVI(lambda z, t: lnpdf(z),
                         D=D,
                         comp_list=[(1., comp)],
                         fix_component_samples=True,
                         break_condition='percent')
    # iteratively add comps
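# The snippet above stops at "iteratively add comps"; the boosting step itself
# is not shown. The sketch below illustrates only the underlying idea, greedily
# growing a mixture and down-weighting the existing components by (1 - rho),
# using plain numpy rather than the vbproj API. All values are illustrative.
import numpy as np
from scipy.special import logsumexp

def mixture_logpdf(z, weights, means, sds):
    # log of sum_k w_k * N(z | m_k, s_k^2), evaluated pointwise
    comp_lps = [np.log(w) - 0.5 * ((z - m) / s)**2 - np.log(s)
                - 0.5 * np.log(2 * np.pi)
                for w, m, s in zip(weights, means, sds)]
    return logsumexp(np.stack(comp_lps), axis=0)

weights, means, sds = [1.0], [0.0], [1.0]   # start from a single component
rho = 0.25                                  # weight given to the new component
weights = [w * (1 - rho) for w in weights] + [rho]
means, sds = means + [2.0], sds + [0.5]     # hypothetical fitted new component
print(mixture_logpdf(np.array([0.0, 2.0]), weights, means, sds))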
def experiment(sname, seed, datasize, nystr=False):

    def LMO_err(params, M=2):
        al, bl = np.exp(params)
        L = bl * bl * np.exp(-L0 / al / al / 2) + 1e-6 * EYEN
        if nystr:
            tmp_mat = L @ eig_vec_K
            C = L - tmp_mat @ np.linalg.inv(
                eig_vec_K.T @ tmp_mat / N2 + inv_eig_val_K) @ tmp_mat.T / N2
            c = C @ W_nystr_Y * N2
        else:
            LWL_inv = chol_inv(L @ W @ L + L / N2 + JITTER * EYEN)
            C = L @ LWL_inv @ L / N2
            c = C @ W @ Y * N2
        c_y = c - Y
        lmo_err = 0
        N = 0
        for ii in range(1):
            permutation = np.random.permutation(X.shape[0])
            for i in range(0, X.shape[0], M):
                indices = permutation[i:i + M]
                K_i = W[np.ix_(indices, indices)] * N2
                C_i = C[np.ix_(indices, indices)]
                c_y_i = c_y[indices]
                b_y = np.linalg.inv(np.eye(M) - C_i @ K_i) @ c_y_i
                lmo_err += b_y.T @ K_i @ b_y
                N += 1
        return lmo_err[0, 0] / N / M**2

    def callback0(params, timer=None):
        global Nfeval, prev_norm, opt_params, opt_test_err
        if Nfeval % 1 == 0:
            al, bl = np.exp(params)
            L = bl * bl * np.exp(-L0 / al / al / 2) + 1e-6 * EYEN
            if nystr:
                alpha = EYEN - eig_vec_K @ np.linalg.inv(
                    eig_vec_K.T @ L @ eig_vec_K / N2 +
                    np.diag(1 / eig_val_K / N2)) @ eig_vec_K.T @ L / N2
                alpha = alpha @ W_nystr @ Y * N2
            else:
                LWL_inv = chol_inv(L @ W @ L + L / N2 + JITTER * EYEN)
                alpha = LWL_inv @ L @ W @ Y
                # L_W_inv = chol_inv(W*N2+L_inv)
            test_L = bl * bl * np.exp(-test_L0 / al / al / 2)
            pred_mean = test_L @ alpha
            if timer:
                return
            test_err = ((pred_mean - test_G)**2).mean()
            # ((pred_mean-test_G)**2/np.diag(pred_cov)).mean()+(np.log(np.diag(pred_cov))).mean()
            norm = alpha.T @ L @ alpha
            Nfeval += 1
            if prev_norm is not None:
                if norm[0, 0] / prev_norm >= 3:
                    if opt_params is None:
                        opt_test_err = test_err
                        opt_params = params
                    print(True, opt_params, opt_test_err, prev_norm)
                    raise Exception
            if prev_norm is None or norm[0, 0] <= prev_norm:
                prev_norm = norm[0, 0]
                opt_test_err = test_err
                opt_params = params
                print('params,test_err, norm: ', opt_params, opt_test_err,
                      prev_norm)

    train, dev, test = load_data(ROOT_PATH +
                                 '/data/zoo/{}_{}.npz'.format(sname, datasize))

    X = np.vstack((train.x, dev.x))
    Y = np.vstack((train.y, dev.y))
    Z = np.vstack((train.z, dev.z))
    test_X = test.x
    test_G = test.g

    t0 = time.time()
    EYEN = np.eye(X.shape[0])
    ak = get_median_inter_mnist(Z)
    N2 = X.shape[0]**2
    W0 = _sqdist(Z, None)
    W = (np.exp(-W0 / ak / ak / 2) + np.exp(-W0 / ak / ak / 200) +
         np.exp(-W0 / ak / ak * 50)) / 3 / N2
    del W0
    L0, test_L0 = _sqdist(X, None), _sqdist(test_X, X)

    # measure time
    # callback0(np.random.randn(2)/10,True)
    # np.save(ROOT_PATH + "/MMR_IVs/results/zoo/" + sname + '/LMO_errs_{}_nystr_{}_time.npy'.format(seed,train.x.shape[0]),time.time()-t0)
    # return

    params0 = np.random.randn(2) / 10
    bounds = None  # [[0.01,10],[0.01,5]]
    if nystr:
        for _ in range(seed + 1):
            random_indices = np.sort(
                np.random.choice(range(W.shape[0]), nystr_M, replace=False))
        eig_val_K, eig_vec_K = nystrom_decomp(W * N2, random_indices)
        inv_eig_val_K = np.diag(1 / eig_val_K / N2)
        W_nystr = eig_vec_K @ np.diag(eig_val_K) @ eig_vec_K.T / N2
        W_nystr_Y = W_nystr @ Y

    obj_grad = value_and_grad(lambda params: LMO_err(params))
    try:
        res = minimize(obj_grad,
                       x0=params0,
                       bounds=bounds,
                       method='L-BFGS-B',
                       jac=True,
                       options={'maxiter': 5000},
                       callback=callback0)
    except Exception as e:
        print(e)

    PATH = ROOT_PATH + "/MMR_IVs/results/zoo/" + sname + "/"
    os.makedirs(PATH, exist_ok=True)
    np.save(PATH + 'LMO_errs_{}_nystr_{}.npy'.format(seed, train.x.shape[0]),
            [opt_params, prev_norm, opt_test_err])
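# 'nystrom_decomp' is called above but not defined in this snippet. Below is a
# minimal sketch of the standard Nystrom approximation, written to match the
# call site's apparent contract (Gram matrix plus landmark indices in,
# approximate eigenvalues and extended eigenvectors out); the project's actual
# helper may differ in details.
import numpy as np

def nystrom_decomp_sketch(K, ind):
    n, m = K.shape[0], len(ind)
    C = K[:, ind]                      # (n, m) cross block
    Wmm = C[ind, :]                    # (m, m) landmark block
    lam, U = np.linalg.eigh(Wmm)
    lam = np.maximum(lam, 1e-12)       # guard tiny/negative eigenvalues
    eig_val = (n / m) * lam            # approximate eigenvalues of K
    eig_vec = np.sqrt(m / n) * C @ U / lam[np.newaxis, :]  # Nystrom extension
    return eig_val, eig_vec            # K ~ eig_vec @ diag(eig_val) @ eig_vec.T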
def main(input_file, input_no):
    other_params = get_other_params()
    data = np.load(input_file, allow_pickle=True)
    data = data.item()
    counter = 5
    for trajcount in range(6):
        q_init = data[trajcount]['q_init']
        q_fin = data[trajcount]['q_fin'][0]
        x_guess = utils.get_real_guess(data, trajcount, 0)
        xs = stack_x(q_init, q_fin)
        q, comp_time_1, msg = utils.solver_solution(x_guess, xs, other_params,
                                                    cost_fun)
        print("Initial problem solved")
        for i in range(1, 7):
            n = data[trajcount]['q_fin'][i].shape[0]
            for j in range(n):
                folder = OUTPUT_FOLDER + "/" + input_no + "/traj" + \
                    str(trajcount) + '/' + str((i - 1) * counter + j)
                if not os.path.exists(folder):
                    os.makedirs(folder)

                q_init_new = q_init.copy()
                q_fin_new = data[trajcount]['q_fin'][i][j, :]
                ys = q.copy()
                x_guess = ys.copy()
                xs_new = stack_x(q_init_new, q_fin_new)
                q_new, comp_time_2, msg = utils.solver_solution(
                    x_guess, xs_new, other_params, cost_fun)
                print("Perturbed problem solved by solver")

                solver_sol = {}
                solver_sol['q'] = q
                solver_sol['q_new'] = q_new
                solver_sol['comp_time_1'] = comp_time_1
                solver_sol['comp_time_2'] = comp_time_2
                sfolder = SOLVER + "/" + input_no + "/traj" + \
                    str(trajcount) + '/' + str((i - 1) * counter + j)
                if not os.path.exists(sfolder):
                    os.makedirs(sfolder)
                with open(sfolder + '/data.npy', 'wb') as f:
                    np.save(f, solver_sol)

                print("Minimal cost from solver initial ",
                      cost_fun(q, xs, other_params))
                print("Minimal cost from solver perturbed ",
                      cost_fun(q_new, xs_new, other_params))

                etas = np.arange(0.05, 1.05, 0.05)
                niters = 30
                q_pred, comp_time_3, saved_cost, saved_eta, saved_ys_pred = \
                    utils.argmin_solution(ys, xs, xs_new, other_params,
                                          etas, niters, cost_fun_jax,
                                          vmap_batched_update_ys, F_YY_fn,
                                          F_XY_fn, get_qfin_jax, stack_x,
                                          folder)
                print("Perturbed problem solved by argmin")
                print("Time taken : ", comp_time_3)

                save_problem_data(q_init, q_fin, q_init_new, q_fin_new, q,
                                  q_new, q_pred, comp_time_1, comp_time_2,
                                  comp_time_3, saved_cost, saved_eta,
                                  saved_ys_pred, other_params, folder)

                # plotting
                utils.plot_trajectory(q, q_new, q_pred, q_init, q_fin,
                                      q_init_new, q_fin_new, folder)
                utils.plot_end_effector_angles(q_new, q_pred, folder)
                utils.plot_joint_angles(q_new, q_pred, folder)
                print("-" * 50)
            print("*" * 50)
        print("=" * 80)
    print('done')
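# The load pattern above (np.load(...).item()) relies on np.save having
# pickled a plain dict. A minimal round-trip sketch; the file name and dict
# contents are illustrative only:
import numpy as np

traj_data = {0: {'q_init': np.zeros(6), 'q_fin': [np.zeros((1, 6))]}}
np.save('traj_data.npy', traj_data)                           # pickles the dict
loaded = np.load('traj_data.npy', allow_pickle=True).item()   # back to a dict
assert set(loaded.keys()) == {0}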
def experiment(sname, seed, nystr=True):

    def LMO_err(params, M=2, verbal=False):
        global Nfeval
        params = np.exp(params)
        al, bl = params[:-1], params[-1]
        # params[:int(n_params/2)], params[int(n_params/2):] # [np.exp(e) for e in params]
        if train.x.shape[1] < 5:
            train_L = bl**2 * np.exp(-train_L0 / al**2 / 2) + 1e-4 * EYEN
        else:
            train_L, dev_L = 0, 0
            for i in range(len(al)):
                train_L += train_L0[i] / al[i]**2
            train_L = bl * bl * np.exp(-train_L / 2) + 1e-4 * EYEN
        tmp_mat = train_L @ eig_vec_K
        C = train_L - tmp_mat @ np.linalg.inv(
            eig_vec_K.T @ tmp_mat / N2 + inv_eig_val) @ tmp_mat.T / N2
        c = C @ W_nystr_Y * N2
        c_y = c - train.y
        lmo_err = 0
        N = 0
        for ii in range(1):
            permutation = np.random.permutation(train.x.shape[0])
            for i in range(0, train.x.shape[0], M):
                indices = permutation[i:i + M]
                K_i = train_W[np.ix_(indices, indices)] * N2
                C_i = C[np.ix_(indices, indices)]
                c_y_i = c_y[indices]
                b_y = np.linalg.inv(np.eye(M) - C_i @ K_i) @ c_y_i
                lmo_err += b_y.T @ K_i @ b_y
                N += 1
        return lmo_err[0, 0] / M**2

    def callback0(params):
        global Nfeval, prev_norm, opt_params, opt_test_err
        if Nfeval % 1 == 0:
            params = np.exp(params)
            print('params:', params)
            al, bl = params[:-1], params[-1]
            if train.x.shape[1] < 5:
                train_L = bl**2 * np.exp(-train_L0 / al**2 / 2) + 1e-4 * EYEN
                test_L = bl**2 * np.exp(-test_L0 / al**2 / 2)
            else:
                train_L, test_L = 0, 0
                for i in range(len(al)):
                    train_L += train_L0[i] / al[i]**2
                    test_L += test_L0[i] / al[i]**2
                train_L = bl * bl * np.exp(-train_L / 2) + 1e-4 * EYEN
                test_L = bl * bl * np.exp(-test_L / 2)
            if nystr:
                tmp_mat = eig_vec_K.T @ train_L
                alpha = EYEN - eig_vec_K @ np.linalg.inv(
                    tmp_mat @ eig_vec_K / N2 + inv_eig_val) @ tmp_mat / N2
                alpha = alpha @ W_nystr_Y * N2
            else:
                LWL_inv = chol_inv(train_L @ train_W @ train_L +
                                   train_L / N2 + JITTER * EYEN)
                alpha = LWL_inv @ train_L @ train_W @ train.y
            pred_mean = test_L @ alpha
            test_err = ((pred_mean - test.g)**2).mean()
            norm = alpha.T @ train_L @ alpha
            Nfeval += 1
            if prev_norm is not None:
                if norm[0, 0] / prev_norm >= 3:
                    if opt_test_err is None:
                        opt_test_err = test_err
                        opt_params = params
                    print(True, opt_params, opt_test_err, prev_norm,
                          norm[0, 0])
                    raise Exception
            if prev_norm is None or norm[0, 0] <= prev_norm:
                prev_norm = norm[0, 0]
                opt_test_err = test_err
                opt_params = params
                print(True, opt_params, opt_test_err, prev_norm, norm[0, 0])

    train, dev, test = load_data(ROOT_PATH + '/data/' + sname +
                                 '/main_orig.npz')
    del dev

    # avoid same indices when run on the cluster
    for _ in range(seed + 1):
        random_indices = np.sort(
            np.random.choice(range(train.x.shape[0]), nystr_M, replace=False))

    EYEN = np.eye(train.x.shape[0])
    N2 = train.x.shape[0]**2

    # precompute to save time on parallelized computation
    if train.z.shape[1] < 5:
        ak = get_median_inter_mnist(train.z)
    else:
        ak = np.load(ROOT_PATH + '/mnist_precomp/{}_ak.npy'.format(sname))
    train_W = np.load(ROOT_PATH +
                      '/mnist_precomp/{}_train_K0.npy'.format(sname))
    train_W = (np.exp(-train_W / ak / ak / 2) +
               np.exp(-train_W / ak / ak / 200) +
               np.exp(-train_W / ak / ak * 50)) / 3 / N2
    if train.x.shape[1] < 5:
        train_L0 = _sqdist(train.x, None)
        test_L0 = _sqdist(test.x, train.x)
    else:
        L0s = np.load(ROOT_PATH + '/mnist_precomp/{}_Ls.npz'.format(sname))
        train_L0 = L0s['train_L0']
        # dev_L0 = L0s['dev_L0']
        test_L0 = L0s['test_L0']
        del L0s
    if train.x.shape[1] < 5:
        params0 = np.random.randn(2) * 0.1
    else:
        params0 = np.random.randn(len(train_L0) + 1) * 0.1
    bounds = None
    eig_val_K, eig_vec_K = nystrom_decomp(train_W * N2, random_indices)
    W_nystr_Y = eig_vec_K @ np.diag(eig_val_K) @ eig_vec_K.T @ train.y / N2
    inv_eig_val = np.diag(1 / eig_val_K / N2)

    obj_grad = value_and_grad(lambda params: LMO_err(params))
    res = minimize(obj_grad,
                   x0=params0,
                   bounds=bounds,
                   method='L-BFGS-B',
                   jac=True,
                   options={'maxiter': 5000, 'disp': True, 'ftol': 0},
                   callback=callback0)

    PATH = ROOT_PATH + "/MMR_IVs/results/" + sname + "/"
    os.makedirs(PATH, exist_ok=True)
    np.save(PATH + 'LMO_errs_{}_nystr.npy'.format(seed),
            [opt_params, prev_norm, opt_test_err])
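# The optimizer call above pairs autograd's value_and_grad with
# scipy.optimize.minimize(jac=True), so a single function returns both the
# loss and its gradient. A self-contained toy version of that pattern:
import autograd.numpy as anp
from autograd import value_and_grad
from scipy.optimize import minimize

def toy_loss(params):
    return anp.sum((anp.exp(params) - 2.0)**2)  # minimized at params = log(2)

res = minimize(value_and_grad(toy_loss), x0=anp.zeros(2),
               method='L-BFGS-B', jac=True)
print(res.x)  # both entries approach log(2) ~ 0.693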
print(np.max(log_iw))
print(log_iw.shape)
psis_lw, K_hat_stan = psislw(log_iw.T)
K_hat_stan_advi_list[j, n] = K_hat_stan
print(psis_lw.shape)
print('K hat statistic for Stan ADVI:')
print(K_hat_stan)

###################### Plotting L2 norm here #################################
plt.figure()
plt.plot(stan_vb_w[:, 0], stan_vb_w[:, 1], 'mo', label='STAN-ADVI')
plt.savefig('vb_w_samples_mf.pdf')
np.save('K_hat_logistic_' + datatype + '_' + algo_name + '_' + str(N) + 'N',
        K_hat_stan_advi_list)

plt.figure()
plt.plot(K_list, np.nanmean(K_hat_stan_advi_list, axis=1), 'r-', alpha=1)
plt.plot(K_list, np.nanmin(K_hat_stan_advi_list, axis=1), 'r-', alpha=0.5)
plt.plot(K_list, np.nanmax(K_hat_stan_advi_list, axis=1), 'r-', alpha=0.5)
plt.xlabel('Dimensions')
plt.ylabel('K-hat')
np.save(
    'K_hat_logistic_' + datatype + '_' + algo_name + '_' + str(N) + 'N' +
    '_samples_' + str(gradsamples), K_hat_stan_advi_list)
#plt.ylim((0,5))
plt.legend()
plt.savefig('Logistic_Regression_K_hat_vs_D_' + datatype + '_' + algo_name +
            '_' + str(N) + 'N.pdf')
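# psislw (Pareto-smoothed importance sampling, Vehtari et al.) takes raw log
# importance weights and returns smoothed log weights plus the Pareto k-hat
# diagnostic; k-hat above roughly 0.7 is the usual unreliability warning.
# A toy check, assuming psislw is in scope with the same interface used above:
import numpy as np

log_iw_toy = np.random.randn(5000)          # stand-in log importance ratios
smoothed_lw, k_hat = psislw(log_iw_toy)
print('k-hat: %.2f (weights considered reliable if < 0.7)' % k_hat)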
print("fidelity reached : ", fidelity_reached) update_error_list.append(1. - fidelity_reached) current_energy = np.sum(mps_func.expectation_values( A_list, H_list)) E_list.append(current_energy) t_list.append(t_list[-1] + dt) print(t_list[-1], E_list[-1]) dir_path = 'data/1d_%s_g%.1f/L%d/' % (Hamiltonian, g, L) if not os.path.exists(dir_path): os.makedirs(dir_path) filename = 'mps_chi%d_%s_energy.npy' % (chi, order) path = dir_path + filename np.save(path, np.array(E_list)) filename = 'mps_chi%d_%s_dt.npy' % (chi, order) path = dir_path + filename np.save(path, np.array(t_list)) filename = 'mps_chi%d_%s_error.npy' % (chi, order) path = dir_path + filename np.save(path, np.array(update_error_list)) dir_path = 'data/1d_%s_g%.1f/' % (Hamiltonian, g) best_E = np.amin(E_list) filename = 'mps_chi%d_%s_energy.csv' % (chi, order) path = dir_path + filename # Try to load file # If data return
num_seed = 123
npr.seed(num_seed)

params_R = np.zeros((n_latent, 2))
steps = np.ones((n_latent, 2))
sCur_R = np.zeros((n_latent, 2))
ELBO_R = np.zeros(n_iter)

params_R[:, 0] = 0.5 + sigma * npr.normal(size=n_latent)
params_R[:, 1] = sigma * npr.normal(size=n_latent)
# softplus transform keeps the variational parameters positive
transformVar = np.log(1. + np.exp(params_R))
ELBO_R[0] = estimate_elbo(transformVar[:, 0], transformVar[:, 1], K, x, alphaz)

for n in range(1, n_iter):
    # chain rule through the softplus: d softplus / d params = sigmoid(params)
    sGrad = reparam_gradient(transformVar[:, 0], transformVar[:, 1], x, K,
                             alphaz, corr=correction,
                             B=B) / (1. + np.exp(-params_R))
    steps, sCur_R = stepSize(n + 1, sCur_R, sGrad, eta)
    params_R = truncate_params(params_R + steps * sGrad)
    transformVar = np.log(1. + np.exp(params_R))
    ELBO_R[n] = estimate_elbo(transformVar[:, 0], transformVar[:, 1], K, x,
                              alphaz)
    if np.mod(n, 100) == 0:
        # periodic checkpoint
        filename = 'results/Olivette_Eta' + str(eta) + '_B' + str(B) + \
            '_corr' + str(correction) + '_ELBO.npy'
        np.save(filename, ELBO_R[:n_iter])
        filename = 'results/Olivette_Eta' + str(eta) + '_B' + str(B) + \
            '_corr' + str(correction) + '_K1_' + str(K[0]) + '_K2_' + \
            str(K[1]) + '_K3_' + str(K[2]) + '_params_R.npy'
        np.save(filename, np.log(1. + np.exp(params_R)))

# final save after the loop completes
filename = 'results/Olivette_Eta' + str(eta) + '_B' + str(B) + \
    '_corr' + str(correction) + '_ELBO.npy'
np.save(filename, ELBO_R[:n_iter])
filename = 'results/Olivette_Eta' + str(eta) + '_B' + str(B) + \
    '_corr' + str(correction) + '_K1_' + str(K[0]) + '_K2_' + str(K[1]) + \
    '_K3_' + str(K[2]) + '_params_R.npy'
np.save(filename, np.log(1. + np.exp(params_R)))
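# Sketch of the softplus trick used above: optimize unconstrained parameters,
# map them through softplus to get positive values, and scale gradients by the
# sigmoid Jacobian. A self-contained numerical check of that Jacobian:
import numpy as np

def softplus(p):
    return np.log(1. + np.exp(p))

p = np.array([-1.0, 0.5, 2.0])
eps = 1e-6
numeric = (softplus(p + eps) - softplus(p - eps)) / (2 * eps)
analytic = 1. / (1. + np.exp(-p))      # sigmoid(p) = d softplus / dp
assert np.allclose(numeric, analytic, atol=1e-6)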