def run(spikes, order, window=1, map_function='nr', lmbda=0.005, max_iter=10):
    """
    Master function of the State-Space Analysis of Spike Correlation package.

    Fits the probability distributions of the natural parameters of
    spike-train interactions over time with the expectation-maximisation
    (EM) algorithm. Expectation and maximisation steps alternate until either
    `max_iter` iterations have run or the convergence measure (the ratio of
    the previous to the current log marginal likelihood) is no longer greater
    than `exp_max.CONVERGED`.

    Note that some of the helper routines called from here are of
    exponential complexity with respect to the `order' parameter.

    :param numpy.ndarray spikes:
        Binary matrix with dimensions (time, runs, cells), in which a `1' in
        location (t, r, c) denotes a spike at time t in run r by cell c.
    :param int order:
        Order of spike-train interactions to estimate, for example,
        2 = pairwise, 3 = triplet-wise...
    :param int window:
        Bin-width for counting spikes, in milliseconds.
    :param string map_function:
        Key into `max_posterior.functions' naming the maximum a-posteriori
        estimator of the natural parameters at each timestep
        (see max_posterior.py).
    :param float lmbda:
        Coefficient on the identity matrix of the initial state-transition
        covariance matrix.
    :param int max_iter:
        Maximum number of iterations for which to run the EM algorithm.

    :returns:
        The container.EMData object holding the smoothed posterior
        distributions of the natural parameters at each timestep,
        conditional upon the given spikes.
    """
    # NOTE: numpy.seterr changes numpy's *global* floating-point error state
    # so that invalid results (NaNs) and underflows raise; it is not restored
    # when this function returns.
    numpy.seterr(invalid='raise', under='raise')

    # Container for the data, the estimates and the EM bookkeeping
    estimator = max_posterior.functions[map_function]
    emd = container.EMData(spikes, order, window, estimator, lmbda)

    # The coordinate-transform tables must exist before any likelihood call
    transforms.initialise(emd.N, emd.order)

    current = probability.log_marginal(emd)
    while emd.iterations < max_iter and emd.convergence > exp_max.CONVERGED:
        exp_max.e_step(emd)
        exp_max.m_step(emd)
        # Shift the log-marginal window: old "current" becomes "previous"
        previous, current = current, probability.log_marginal(emd)
        emd.iterations += 1
        emd.convergence = previous / current
    return emd
def __init__(self, spikes, order, window, param_est, param_est_eta,
             map_function, lmbda1, lmbda2):
    """
    Initialise the EM-data container and the starting state-space estimates.

    :param numpy.ndarray spikes:
        Binary array of shape (time, runs, cells).
    :param int order:
        Order of spike-train interactions to estimate.
    :param int window:
        Bin-width used when counting spikes.
    :param str param_est:
        'exact' or 'pseudo'. Selects the table of MAP functions
        (`max_posterior` vs. `pseudo_likelihood`). It also selects whether
        full (T, D, D) covariances or only their diagonals (T, D) are stored.
    :param str param_est_eta:
        Key into `log_marginal_functions`; stored as `self.marg_llk`.
    :param str map_function:
        Key into the MAP-function table chosen by `param_est`.
    :param float lmbda1:
        The first N x N block of `self.Q` is (1 / lmbda1) * I.
    :param float lmbda2:
        The remaining (D - N) x (D - N) block of `self.Q` is (1 / lmbda2) * I.
    :raises ValueError:
        If `param_est` is neither 'exact' nor 'pseudo'.
    """
    # Fail early: any other value would leave `self.max_posterior` undefined
    # and only surface as an AttributeError deep inside the EM loop.
    if param_est not in ('exact', 'pseudo'):
        raise ValueError(
            "param_est must be 'exact' or 'pseudo', got %r" % (param_est,))

    # Record the input parameters
    self.spikes, self.order, self.window = spikes, order, window
    T, self.R, self.N = self.spikes.shape
    if param_est == 'exact':
        transforms.initialise(self.N, self.order)
        self.max_posterior = max_posterior.functions[map_function]
    else:  # 'pseudo'
        pseudo_likelihood.compute_Fx_s(self.spikes, self.order)
        self.max_posterior = pseudo_likelihood.functions[map_function]
    self.param_est_theta = param_est
    self.param_est_eta = param_est_eta
    self.marg_llk = log_marginal_functions[param_est_eta]

    # Compute the `sample' spike-train interactions from the input spikes
    self.y = transforms.compute_y(self.spikes, self.order, self.window)
    # Count timesteps and interaction dimensions
    self.T, self.D = self.y.shape
    assert self.T == T / window

    # Initialise one-step-prediction, filtered and smoothed means
    self.theta_o = numpy.zeros((self.T, self.D))
    self.theta_f = numpy.zeros((self.T, self.D))
    self.theta_s = numpy.zeros((self.T, self.D))

    if param_est == 'exact':
        # Full covariances: 0.1 * identity at every timestep, shape (T, D, D)
        I = numpy.tile(numpy.identity(self.D), (self.T, 1, 1))
        self.sigma_o = .1 * I
        self.sigma_o_inv = 1. / .1 * I
        self.sigma_f = numpy.copy(self.sigma_o)
        self.sigma_s = numpy.copy(self.sigma_o)
        self.sigma_s_lag = numpy.copy(self.sigma_o)
    else:
        # Pseudo-likelihood: store only the diagonals, shape (T, D)
        self.sigma_o = .1 * numpy.ones((self.T, self.D))
        self.sigma_o_inv = 1. / .1 * numpy.ones((self.T, self.D))
        self.sigma_f = .1 * numpy.ones((self.T, self.D))
        self.sigma_s = .1 * numpy.ones((self.T, self.D))
        self.sigma_s_lag = .1 * numpy.ones((self.T, self.D))

    # Autoregressive (F) and state-transition covariance (Q) hyperparameters
    self.F = numpy.identity(self.D)
    self.Q = numpy.zeros([self.D, self.D])
    self.Q[:self.N, :self.N] = 1. / lmbda1 * numpy.identity(self.N)
    self.Q[self.N:, self.N:] = 1. / lmbda2 * numpy.identity(self.D - self.N)
    self.mllk = numpy.inf
    # Metadata about EM algorithm execution
    self.iterations, self.convergence = 0, numpy.inf
def __init__(self, spikes, order, window, param_est, param_est_eta,
             map_function, lmbda1, lmbda2, theta_o, sigma_o):
    """
    Initialise the EM-data container with user-supplied starting values.

    :param numpy.ndarray spikes:
        Binary array of shape (time, runs, cells).
    :param int order:
        Order of spike-train interactions to estimate.
    :param int window:
        Bin-width used when counting spikes.
    :param str param_est:
        'exact' or 'pseudo'. Selects the table of MAP functions
        (`max_posterior` vs. `pseudo_likelihood`). It also selects whether
        full (T, D, D) covariances or only their diagonals (T, D) are stored.
    :param str param_est_eta:
        Key into `log_marginal_functions`; stored as `self.marg_llk`.
    :param str map_function:
        Key into the MAP-function table chosen by `param_est`.
    :param float lmbda1:
        The first N x N block of `self.Q` is (1 / lmbda1) * I.
    :param float lmbda2:
        The remaining (D - N) x (D - N) block of `self.Q` is (1 / lmbda2) * I.
    :param theta_o:
        Initial one-step-prediction mean; multiplied into a (T, D) array of
        ones, so it must broadcast against that shape.
    :param sigma_o:
        Initial one-step-prediction variance scale (presumably a scalar):
        sigma_o * I per timestep for 'exact', its diagonal for 'pseudo'.
    :raises ValueError:
        If `param_est` is neither 'exact' nor 'pseudo'.
    """
    # Fail early: any other value would leave `self.max_posterior` undefined
    # and only surface as an AttributeError deep inside the EM loop.
    if param_est not in ('exact', 'pseudo'):
        raise ValueError(
            "param_est must be 'exact' or 'pseudo', got %r" % (param_est,))

    # Record the input parameters
    self.spikes, self.order, self.window = spikes, order, window
    T, self.R, self.N = self.spikes.shape
    if param_est == 'exact':
        transforms.initialise(self.N, self.order)
        self.max_posterior = max_posterior.functions[map_function]
    else:  # 'pseudo'
        pseudo_likelihood.compute_Fx_s(self.spikes, self.order)
        self.max_posterior = pseudo_likelihood.functions[map_function]
    self.param_est_theta = param_est
    self.param_est_eta = param_est_eta
    self.marg_llk = log_marginal_functions[param_est_eta]

    # Compute the `sample' spike-train interactions from the input spikes
    self.y = transforms.compute_y(self.spikes, self.order, self.window)
    # Count timesteps and interaction dimensions
    self.T, self.D = self.y.shape
    assert self.T == T / window

    # Initialise one-step-prediction, filtered and smoothed means
    self.theta_o = numpy.ones((self.T, self.D)) * theta_o
    self.theta_f = numpy.zeros((self.T, self.D))
    self.theta_s = numpy.zeros((self.T, self.D))

    if param_est == 'exact':
        # Full covariances: sigma_o * identity at every timestep, (T, D, D)
        I = numpy.tile(numpy.identity(self.D), (self.T, 1, 1))
        self.sigma_o = sigma_o * I
        self.sigma_o_inv = numpy.linalg.inv(self.sigma_o)
        self.sigma_f = numpy.copy(self.sigma_o)
        self.sigma_s = numpy.copy(self.sigma_o)
        self.sigma_s_lag = numpy.copy(self.sigma_o)
    else:
        # Pseudo-likelihood: store only the diagonals, shape (T, D)
        self.sigma_o = sigma_o * numpy.ones((self.T, self.D))
        self.sigma_o_inv = 1. / sigma_o * numpy.ones((self.T, self.D))
        # NOTE(review): unlike the 'exact' branch, these start at 0.1 rather
        # than at sigma_o -- confirm this asymmetry is intended.
        self.sigma_f = .1 * numpy.ones((self.T, self.D))
        self.sigma_s = .1 * numpy.ones((self.T, self.D))
        self.sigma_s_lag = .1 * numpy.ones((self.T, self.D))

    # Autoregressive (F) and state-transition covariance (Q) hyperparameters
    self.F = numpy.identity(self.D)
    self.Q = numpy.zeros([self.D, self.D])
    self.Q[:self.N, :self.N] = 1. / lmbda1 * numpy.identity(self.N)
    self.Q[self.N:, self.N:] = 1. / lmbda2 * numpy.identity(self.D - self.N)
    self.mllk = numpy.inf
    # Metadata about EM algorithm execution
    self.iterations, self.convergence = 0, numpy.inf
def compute_eta(theta, N, O, R=1000):
    """ Computes eta from given theta.

    :param numpy.ndarray theta:
        (t, d) array with natural parameters
    :param int N:
        number of cells
    :param int O:
        order of model; only 1 and 2 are supported
    :param int R:
        trials that should be sampled to estimate eta
    :return numpy.ndarray, list:
        (t, d) array with expectation parameters (eta) and a list with the
        indices of bins whose eta was estimated by sampling
    :raises NotImplementedError:
        if O is neither 1 nor 2

    Details: For first order the analytical (independent-cell) solution is
    used. For second order with at most 15 cells the exact solution is
    computed. For larger second-order networks it tries to solve the forward
    problem with TAP, and falls back to Gibbs sampling for every bin where
    that fails.
    """
    # Previously any other order silently returned uninitialised memory.
    if O not in (1, 2):
        raise NotImplementedError(
            'compute_eta supports only O=1 or O=2, got O=%r' % (O,))
    T, D = theta.shape
    eta = numpy.empty(theta.shape)
    bins_to_sample = []
    if O == 1:
        eta = compute_ind_eta(theta[:, :N])
    elif N > 15:
        # Large ensemble: approximate with the TAP forward problem
        for i in range(T):
            try:
                eta[i] = mean_field.forward_problem(theta[i], N, 'TAP')
            except Exception:
                # Deliberate best-effort: any failure marks the bin for sampling
                bins_to_sample.append(i)
        if bins_to_sample:
            theta_to_sample = numpy.array(theta[bins_to_sample], dtype=float)
            spikes = synthesis.generate_spikes_gibbs_parallel(
                theta_to_sample, N, O, R, sample_steps=100)
            eta_from_sample = transforms.compute_y(spikes, O, 1)
            for row, bin_idx in enumerate(bins_to_sample):
                eta[bin_idx] = eta_from_sample[row]
    else:
        # Few cells: compute exact rates from the full pattern distribution
        transforms.initialise(N, O)
        for i in range(T):
            p = transforms.compute_p(theta[i])
            eta[i] = transforms.compute_eta(p)
    return eta, bins_to_sample
def compute_eta(theta, N, O, R=1000):
    """ Computes eta from given theta.

    :param numpy.ndarray theta:
        (t, d) array with natural parameters
    :param int N:
        number of cells
    :param int O:
        order of model; only 1 and 2 are supported
    :param int R:
        trials that should be sampled to estimate eta
    :return numpy.ndarray, list:
        (t, d) array with expectation parameters (eta) and a list with the
        indices of bins whose eta was estimated by sampling
    :raises NotImplementedError:
        if O is neither 1 nor 2

    Details: For first order the analytical (independent-cell) solution is
    used. For second order with at most 15 cells the exact solution is
    computed. For larger second-order networks it tries to solve the forward
    problem with TAP, and falls back to Gibbs sampling for every bin where
    that fails.
    """
    # Previously any other order silently returned uninitialised memory.
    if O not in (1, 2):
        raise NotImplementedError(
            'compute_eta supports only O=1 or O=2, got O=%r' % (O,))
    T, D = theta.shape
    eta = numpy.empty(theta.shape)
    bins_to_sample = []
    if O == 1:
        eta = compute_ind_eta(theta[:, :N])
    elif N > 15:
        # Large ensemble: approximate with the TAP forward problem
        for i in range(T):
            try:
                eta[i] = mean_field.forward_problem(theta[i], N, 'TAP')
            except Exception:
                # Deliberate best-effort: any failure marks the bin for sampling
                bins_to_sample.append(i)
        if bins_to_sample:
            theta_to_sample = numpy.array(theta[bins_to_sample], dtype=float)
            spikes = synthesis.generate_spikes_gibbs_parallel(
                theta_to_sample, N, O, R, sample_steps=100)
            eta_from_sample = transforms.compute_y(spikes, O, 1)
            for row, bin_idx in enumerate(bins_to_sample):
                eta[bin_idx] = eta_from_sample[row]
    else:
        # Few cells: compute exact rates from the full pattern distribution
        transforms.initialise(N, O)
        for i in range(T):
            p = transforms.compute_p(theta[i])
            eta[i] = transforms.compute_eta(p)
    return eta, bins_to_sample
def run_ssasc(self, theta, N, O):
    """
    Synthesise spikes from `theta`, fit them with the EM algorithm and assert
    that the fit is close to the truth.

    The check requires the KL divergence between the true and the smoothed
    parameters to stay at or below .01, ignoring the first and last 50
    timesteps. On failure the fit is plotted before the assertion fails.
    """
    # Pattern-probability tables are needed by compute_p below
    transforms.initialise(N, O)

    # One row of pattern probabilities per timestep
    probs = numpy.zeros((self.T, 2 ** N))
    for step in range(self.T):
        probs[step, :] = transforms.compute_p(theta[step, :])

    spikes = synthesis.generate_spikes(probs, self.R, seed=self.spike_seed)
    emd = __init__.run(spikes, O)

    divergence = klic(theta, emd.theta_s, emd.N)
    too_far = numpy.any(divergence[50:-50] > .01)
    if too_far:
        self.plot(theta, emd.theta_s, emd.sigma_s, emd.y, divergence,
                  emd.N, emd.T, emd.D)
    self.assertFalse(too_far)
def compute_psi(theta, N, O, R=1000):
    """ Computes psi from given theta.

    :param numpy.ndarray theta:
        (t, d) array with natural parameters
    :param int N:
        number of cells
    :param int O:
        order of model; only 1 and 2 are supported
    :param int R:
        not used by this function; kept for interface compatibility
    :return numpy.ndarray, list:
        (t,) array with the log-partition function and a list with indices
        of bins for which `ot_estimator` reported sampling
    :raises NotImplementedError:
        if O is neither 1 nor 2

    For first order the analytical solution is used. For networks with 15
    units or fewer the exact solution is computed. Otherwise, the
    Ogata-Tanemura estimator is used, starting from the independent model; it
    tries to solve the forward problem and samples where it fails.
    """
    # Previously any other order silently returned uninitialised memory.
    if O not in (1, 2):
        raise NotImplementedError(
            'compute_psi supports only O=1 or O=2, got O=%r' % (O,))
    T = theta.shape[0]
    bins_sampled = []
    psi = numpy.empty(T)
    if O == 1:
        psi = compute_ind_psi(theta[:, :N])
    elif N > 15:
        # Large ensemble: Ogata-Tanemura estimator from the independent model
        theta0 = numpy.copy(theta)
        theta0[:, N:] = 0
        psi0 = compute_ind_psi(theta0[:, :N])
        for i in range(T):
            psi[i], sampled = ot_estimator(theta0[i], psi0[i], theta[i],
                                           N, O, N)
            # save bin if sampled
            if sampled:
                bins_sampled.append(i)
    else:
        # Few cells: exact result
        transforms.initialise(N, O)
        for i in range(T):
            psi[i] = transforms.compute_psi(theta[i])
    return psi, bins_sampled
def compute_psi(theta, N, O, R=1000):
    """ Computes psi from given theta.

    :param numpy.ndarray theta:
        (t, d) array with natural parameters
    :param int N:
        number of cells
    :param int O:
        order of model; only 1 and 2 are supported
    :param int R:
        not used by this function; kept for interface compatibility
    :return numpy.ndarray:
        (t,) array with the log-partition function
    :raises NotImplementedError:
        if O is neither 1 nor 2

    For first order the analytical solution is used. For networks with 15
    units or fewer the exact solution is computed. Otherwise, the
    Ogata-Tanemura estimator is used, starting from the independent model; it
    tries to solve the forward problem and samples where it fails.
    """
    # Previously any other order silently returned uninitialised memory.
    if O not in (1, 2):
        raise NotImplementedError(
            'compute_psi supports only O=1 or O=2, got O=%r' % (O,))
    T = theta.shape[0]
    psi = numpy.empty(T)
    if O == 1:
        psi = compute_ind_psi(theta[:, :N])
    elif N > 15:
        # Large ensemble: Ogata-Tanemura estimator from the independent model.
        # NOTE(review): assumes this version of ot_estimator returns a scalar;
        # a sibling variant unpacks (psi, sampled) from it.
        theta0 = numpy.copy(theta)
        theta0[:, N:] = 0
        psi0 = compute_ind_psi(theta0[:, :N])
        for i in range(T):
            psi[i] = ot_estimator(theta0[i], psi0[i], theta[i], N, O, N)
    else:
        # Few cells: exact result
        transforms.initialise(N, O)
        for i in range(T):
            psi[i] = transforms.compute_psi(theta[i])
    return psi
# Number of dimensions of the generative model D = 2 # Number of graphs in the fitted model D_fit = 2 # Precision lmbda = 1000 # Random seed seed = numpy.random.seed(0) # p_map and e_map (Initialise) if exact: transforms.initialise(N, 2) # Number of dimensions of the graphs dim = transforms.compute_D(N, 2) # Generative J matrix J_gen = numpy.zeros((dim, D_fit)) x1 = numpy.array([-1, 1, -1, 1, 1, 1, -1, 1, -1]) x2 = numpy.array([1, 1, 1, -1, 1, -1, -1, 1, -1]) a1 = numpy.outer(x1, x1) a2 = numpy.outer(x2, x2) J_gen[:N, 0] = numpy.diag(a1) J_gen[:N, 1] = numpy.diag(a2) J_gen[N:, 0] = a1[numpy.triu_indices(N, k=1)] J_gen[N:, 1] = a2[numpy.triu_indices(N, k=1)] J_gen = J_gen / (numpy.var(J_gen, axis=0)**0.5)
def generate_data_ctime(data_path='../Data/', max_network_size=60,
                        num_procs=4):
    """
    Time the pseudo-likelihood EM fit for growing networks and store the
    timings in `comp_time_data.h5`.

    Independent 10-cell networks are generated and concatenated, with no
    interactions between networks, into populations of 10, 20, ...,
    `max_network_size` cells. Each population is Gibbs-sampled and then
    fitted twice. The 'bethe_hybrid' fit time goes to ctime[0, i]. The 'mf'
    fit time goes to ctime[1, i]; it is NaN if that fit raises.

    :param str data_path: directory, with trailing separator, for the output.
    :param int max_network_size: largest population size (presumably a
        multiple of 10).
    :param int num_procs: number of processes used for Gibbs sampling.
    """
    N, O, T = 10, 2, 500
    # Trials sampled per population
    R = 500
    # Floor division keeps this an int on Python 3 as well as Python 2.
    num_of_networks = max_network_size // N

    # Common first-order drive: zero for the first 100 bins, then a bump
    mu = numpy.zeros(T)
    x = numpy.arange(1, 401)
    mu[100:] = 1. * (3. / (2. * numpy.pi * (x / 400. * 3.) ** 3)) ** .5 * \
        numpy.exp(-3. * ((x / 400. * 3.) - 1.) ** 2 / (2. * (x / 400. * 3.)))

    D = transforms.compute_D(N, O)
    thetas = numpy.empty([num_of_networks, T, D])
    transforms.initialise(N, O)
    for i in range(num_of_networks):
        thetas[i] = synthesis.generate_thetas(N, O, T, mu1=-2.)
        thetas[i, :, :N] += mu[:, numpy.newaxis]

    out_file = data_path + 'comp_time_data.h5'
    with h5py.File(out_file, 'w') as f:
        # Population sizes, one per stored timing column
        f.create_dataset('N', data=N * numpy.arange(1, num_of_networks + 1))
        f.create_dataset('ctime', shape=[2, num_of_networks])

    triu_idx = numpy.triu_indices(N, k=1)
    for i in range(num_of_networks):
        n_all = (i + 1) * N
        print('N=%d' % n_all)
        # Concatenate networks 0..i; cross-network interactions stay zero
        theta_all = numpy.empty([T, transforms.compute_D(n_all, O)])
        triu_idx_all = numpy.triu_indices(n_all, k=1)
        for j in range(i + 1):
            theta_all[:, N * j:N * (j + 1)] = thetas[j, :, :N]
        for t in range(T):
            theta_ij = numpy.zeros([n_all, n_all])
            for j in range(i + 1):
                theta_ij[triu_idx[0] + j * N, triu_idx[1] + j * N] = \
                    thetas[j, t, N:]
            theta_all[t, n_all:] = theta_ij[triu_idx_all]
        spikes = synthesis.generate_spikes_gibbs_parallel(
            theta_all, n_all, O, R, sample_steps=10, num_proc=num_procs)

        t1 = time.time()
        __init__.run(spikes, O, map_function='cg', param_est='pseudo',
                     param_est_eta='bethe_hybrid', lmbda1=100, lmbda2=200)
        ctime_bethe = time.time() - t1
        with h5py.File(out_file, 'r+') as f:
            f['ctime'][0, i] = ctime_bethe

        try:
            t1 = time.time()
            __init__.run(spikes, O, map_function='cg', param_est='pseudo',
                         param_est_eta='mf', lmbda1=100, lmbda2=200)
            ctime_TAP = time.time() - t1
        except Exception:
            # A failing mean-field fit is recorded as NaN instead of
            # aborting the whole sweep.
            ctime_TAP = numpy.nan
        with h5py.File(out_file, 'r+') as f:
            f['ctime'][1, i] = ctime_TAP
def generate_data_figure3and4(data_path='../Data/', num_of_iterations=10):
    """
    Compare the 'exact', 'bethe_hybrid' and 'mf' fits on repeated data sets.

    Reads the 15-cell network `data/theta1` from `figure1data.h5`, which
    generate_data_figure1 writes. Then, `num_of_iterations` times, it draws
    R=200 trials of spikes and fits them with each method. Each method's
    group in `figure2and3data.h5` receives the mean squared errors of theta
    and psi and the psi trace per iteration. It also receives the theta and
    sigma estimates from the first iteration.

    :param str data_path: directory, with trailing separator, of the HDF5 files.
    :param int num_of_iterations: number of independent data sets to fit.
    """
    R, T, N, O = 200, 500, 15, 2
    # `[()]` reads the whole dataset; `Dataset.value` was removed in h5py 3.
    with h5py.File(data_path + 'figure1data.h5', 'r') as f:
        theta = f['data']['theta1'][()]

    # Ground truth: exact log-partition and pattern probabilities per bin
    transforms.initialise(N, O)
    psi_true = numpy.empty(T)
    for i in range(T):
        psi_true[i] = transforms.compute_psi(theta[i])
    p = numpy.zeros((T, 2 ** N))
    for i in range(T):
        p[i, :] = transforms.compute_p(theta[i, :])

    fitting_methods = ['exact', 'bethe_hybrid', 'mf']
    out_file = data_path + 'figure2and3data.h5'
    with h5py.File(out_file, 'w') as f:
        f.create_dataset('psi_true', data=psi_true)
        f.create_dataset('theta_true', data=theta)
        for fit in fitting_methods:
            g = f.create_group(fit)
            g.create_dataset('MISE_theta', shape=[num_of_iterations])
            g.create_dataset('MISE_psi', shape=[num_of_iterations])
            g.create_dataset('psi', shape=[num_of_iterations, T])

    for iteration in range(num_of_iterations):
        print('Iteration %d' % iteration)
        spikes = synthesis.generate_spikes(p, R, seed=None)
        for fit in fitting_methods:
            if fit == 'exact':
                emd = __init__.run(spikes, O, map_function='cg',
                                   param_est='exact', param_est_eta='exact')
            else:
                emd = __init__.run(spikes, O, map_function='cg',
                                   param_est='pseudo', param_est_eta=fit)
            # Log-partition of the smoothed estimate, with the method's own
            # approximation
            psi = numpy.empty(T)
            if fit == 'exact':
                for i in range(T):
                    psi[i] = transforms.compute_psi(emd.theta_s[i])
            elif fit == 'bethe_hybrid':
                for i in range(T):
                    psi[i] = bethe_approximation.compute_eta_hybrid(
                        emd.theta_s[i], N, return_psi=1)[1]
            elif fit == 'mf':
                for i in range(T):
                    eta_mf = mean_field.forward_problem(emd.theta_s[i], N,
                                                        'TAP')
                    psi[i] = mean_field.compute_psi(emd.theta_s[i], eta_mf, N)
            mise_theta = numpy.mean((theta - emd.theta_s) ** 2)
            mise_psi = numpy.mean((psi_true - psi) ** 2)
            with h5py.File(out_file, 'r+') as f:
                g = f[fit]
                g['MISE_theta'][iteration] = mise_theta
                g['MISE_psi'][iteration] = mise_psi
                if iteration == 0:
                    g.create_dataset('theta', data=emd.theta_s)
                    g.create_dataset('sigma', data=emd.sigma_s)
                g['psi'][iteration] = psi
            print('Fitted with %s' % fit)
def _figure2_combined_theta(thetas, n_nets, N, O):
    """Concatenate the first `n_nets` N-cell networks of `thetas` into one
    (T, D) parameter array with no interactions between networks."""
    T = thetas.shape[1]
    n_all = n_nets * N
    theta_all = numpy.empty([T, transforms.compute_D(n_all, O)])
    triu_idx = numpy.triu_indices(N, k=1)
    triu_idx_all = numpy.triu_indices(n_all, k=1)
    # First n_all columns: per-cell terms of each network in turn
    for j in range(n_nets):
        theta_all[:, N * j:N * (j + 1)] = thetas[j, :, :N]
    # Remaining columns: upper triangle of a block-diagonal coupling matrix
    for t in range(T):
        theta_ij = numpy.zeros([n_all, n_all])
        for j in range(n_nets):
            theta_ij[triu_idx[0] + j * N, triu_idx[1] + j * N] = \
                thetas[j, t, N:]
        theta_all[t, n_all:] = theta_ij[triu_idx_all]
    return theta_all


def _figure2_create_error_group(f, name, num_of_networks, T):
    """Create the HDF5 group `name` holding per-network errors and traces."""
    g = f.create_group(name)
    for key in ('MISE_thetas', 'MISE_population_rate', 'MISE_psi', 'MISE_S',
                'MISE_C'):
        g.create_dataset(key, shape=[num_of_networks])
    for key in ('population_rate', 'psi', 'S', 'C'):
        g.create_dataset(key, shape=[num_of_networks, T])


def _figure2_fit_and_store(out_file, group_name, thetas, etas, psi, S, C,
                           N, O, R):
    """Fit growing concatenations of the networks in `thetas` from R sampled
    trials and write estimates and errors into `group_name` of `out_file`."""
    num_of_networks, T = thetas.shape[:2]
    for i in range(num_of_networks):
        n_nets = i + 1
        n_all = n_nets * N
        print('N=%d' % n_all)
        theta_all = _figure2_combined_theta(thetas, n_nets, N, O)
        spikes = synthesis.generate_spikes_gibbs_parallel(
            theta_all, n_all, O, R, sample_steps=10, num_proc=4)
        emd = __init__.run(spikes, O, map_function='cg', param_est='pseudo',
                           param_est_eta='bethe_hybrid', lmbda1=100,
                           lmbda2=200)
        eta_est = numpy.empty(emd.theta_s.shape)
        psi_est = numpy.empty(T)
        S_est = numpy.empty(T)
        C_est = numpy.empty(T)
        for t in range(T):
            theta_t = emd.theta_s[t]
            eta_est[t], psi_est[t] = bethe_approximation.compute_eta_hybrid(
                theta_t, n_all, return_psi=1)
            # psi at theta scaled by .999 and 1.001 for a finite difference
            psi1 = bethe_approximation.compute_eta_hybrid(
                .999 * theta_t, n_all, return_psi=1)[1]
            psi2 = bethe_approximation.compute_eta_hybrid(
                1.001 * theta_t, n_all, return_psi=1)[1]
            S_est[t] = -(numpy.sum(eta_est[t] * theta_t) - psi_est[t])
            C_est[t] = (psi1 - 2. * psi_est[t] + psi2) / .001 ** 2
        # Convert from nats to bits
        S_est /= numpy.log(2)
        C_est /= numpy.log(2)
        # Ground truth of the combined network: sums/means over its parts
        population_rate = numpy.mean(
            numpy.mean(etas[:n_nets, :, :N], axis=0), axis=1)
        population_rate_est = numpy.mean(eta_est[:, :n_all], axis=1)
        psi_true = numpy.sum(psi[:n_nets, :], axis=0)
        S_true = numpy.sum(S[:n_nets, :], axis=0)
        C_true = numpy.sum(C[:n_nets, :], axis=0)
        with h5py.File(out_file, 'r+') as f:
            g = f[group_name]
            g['MISE_thetas'][i] = numpy.mean((theta_all - emd.theta_s) ** 2)
            g['MISE_population_rate'][i] = numpy.mean(
                (population_rate - population_rate_est) ** 2)
            g['MISE_psi'][i] = numpy.mean((psi_est - psi_true) ** 2)
            g['MISE_S'][i] = numpy.mean((S_est - S_true) ** 2)
            g['MISE_C'][i] = numpy.mean((C_est - C_true) ** 2)
            g['population_rate'][i] = population_rate_est
            g['psi'][i] = psi_est
            g['S'][i] = S_est
            g['C'][i] = C_est


def generate_data_figure2(data_path='../Data/', max_network_size=60):
    """
    Generate the data for figure 2 and store it in `figure2data.h5`.

    Independent 10-cell networks are generated, and their exact eta, psi,
    S = -(eta . theta - psi) and C are computed. C is the finite-difference
    second derivative of psi under scaling theta by .999 / 1.001. S and C
    are in bits. Concatenations of 1, 2, ... networks are then sampled and
    fitted with the pseudo-likelihood EM ('bethe_hybrid'). This is done twice:
    with R=200 trials (group 'error') and R=500 trials (group 'error500').

    :param str data_path: directory, with trailing separator, for the output.
    :param int max_network_size: largest population size (presumably a
        multiple of 10).
    """
    N, O, T = 10, 2, 500
    # Floor division keeps this an int on Python 3 as well as Python 2.
    num_of_networks = max_network_size // N

    # Common first-order drive: zero for the first 100 bins, then a bump
    mu = numpy.zeros(T)
    x = numpy.arange(1, 401)
    mu[100:] = 1. * (3. / (2. * numpy.pi * (x / 400. * 3.) ** 3)) ** .5 * \
        numpy.exp(-3. * ((x / 400. * 3.) - 1.) ** 2 / (2. * (x / 400. * 3.)))

    D = transforms.compute_D(N, O)
    thetas = numpy.empty([num_of_networks, T, D])
    etas = numpy.empty([num_of_networks, T, D])
    psi = numpy.empty([num_of_networks, T])
    S = numpy.empty([num_of_networks, T])
    C = numpy.empty([num_of_networks, T])
    transforms.initialise(N, O)
    for i in range(num_of_networks):
        thetas[i] = synthesis.generate_thetas(N, O, T, mu1=-2.)
        thetas[i, :, :N] += mu[:, numpy.newaxis]
        for t in range(T):
            p = transforms.compute_p(thetas[i, t])
            etas[i, t] = transforms.compute_eta(p)
            psi[i, t] = transforms.compute_psi(thetas[i, t])
            psi1 = transforms.compute_psi(.999 * thetas[i, t])
            psi2 = transforms.compute_psi(1.001 * thetas[i, t])
            C[i, t] = (psi1 - 2. * psi[i, t] + psi2) / .001 ** 2
            S[i, t] = -(numpy.sum(etas[i, t] * thetas[i, t]) - psi[i, t])
    C /= numpy.log(2)
    S /= numpy.log(2)

    out_file = data_path + 'figure2data.h5'
    with h5py.File(out_file, 'w') as f:
        g1 = f.create_group('data')
        g1.create_dataset('thetas', data=thetas)
        g1.create_dataset('etas', data=etas)
        g1.create_dataset('psi', data=psi)
        g1.create_dataset('S', data=S)
        g1.create_dataset('C', data=C)
        _figure2_create_error_group(f, 'error', num_of_networks, T)
    _figure2_fit_and_store(out_file, 'error', thetas, etas, psi, S, C,
                           N, O, 200)

    # Same sweep with more trials. The in-memory arrays are exactly what was
    # just written to 'data', so re-reading them is unnecessary.
    with h5py.File(out_file, 'r+') as f:
        _figure2_create_error_group(f, 'error500', num_of_networks, T)
    _figure2_fit_and_store(out_file, 'error500', thetas, etas, psi, S, C,
                           N, O, 500)
# Number of time bins, trials and neurons T, R, N = emd.spikes.shape # Number of dimensions D = emd.D # Order of interactions considered in the Ising models O = emd.order # Initialise the theta matrix and the sigma matrix theta_m = numpy.zeros((12, T, D)) theta_m[0, :, :] = emd.theta_s sigma_m = numpy.zeros((12, T, D)) if emd.marg_llk == probability.log_marginal: # i.e. if exact model is used transforms.initialise(N, O) for t in range(T): sigma_m[0, t] = numpy.diag( emd.sigma_s[t] ) # Because emd.sigma_s is ((T,D,D)) if exact, only need the diagonal for fig1 else: sigma_m[0] = emd.sigma_s # Initialize the spike train array spikes_m = numpy.zeros((12, T, R, N)) spikes_m[0, :, :, :] = emd.spikes # Iteration over all orientation for orientation in range(30, 360, 30): f = open( directory + '/Data/m' + str(monkey + 1) + 'd' + str(orientation) +
def generate_data_figure1(data_path='../Data/'):
    """
    Generate, fit and post-process the data for figure 1 and store it in
    `figure1data.h5`.

    Two independent 15-cell networks are combined into one 30-cell network
    with no interactions between them. The file receives:

    * their exact psi, S = -(eta . theta - psi) and C, in bits. C is the
      finite-difference second derivative of psi when theta is scaled by
      .999 / 1.001.
    * R=200 trials of Gibbs-sampled spikes.
    * the pseudo-likelihood EM fit ('bethe_hybrid').
    * 100 posterior samples of theta per bin, drawn from the smoothed mean
      and the diagonal variance, with their eta, psi, S and C.

    :param str data_path: directory, with trailing separator, for the output.
    """
    N, O, R, T = 15, 2, 200, 500
    # Common first-order drive: zero for the first 100 bins, then a bump
    mu = numpy.zeros(T)
    x = numpy.arange(1, 401)
    mu[100:] = 1. * (3. / (2. * numpy.pi * (x / 400. * 3.) ** 3)) ** .5 * \
        numpy.exp(-3. * ((x / 400. * 3.) - 1.) ** 2 / (2. * (x / 400. * 3.)))
    theta1 = synthesis.generate_thetas(N, O, T, mu1=-2.)
    theta2 = synthesis.generate_thetas(N, O, T, mu1=-2.)
    theta1[:, :N] += mu[:, numpy.newaxis]
    theta2[:, :N] += mu[:, numpy.newaxis]

    # Combined 2N-cell model: block-diagonal couplings, no cross terms
    D = transforms.compute_D(N * 2, O)
    theta_all = numpy.empty([T, D])
    theta_all[:, :N] = theta1[:, :N]
    theta_all[:, N:2 * N] = theta2[:, :N]
    triu_idx = numpy.triu_indices(N, k=1)
    triu_idx_all = numpy.triu_indices(2 * N, k=1)
    for t in range(T):
        theta_ij = numpy.zeros([2 * N, 2 * N])
        theta_ij[triu_idx] = theta1[t, N:]
        theta_ij[triu_idx[0] + N, triu_idx[1] + N] = theta2[t, N:]
        theta_all[t, 2 * N:] = theta_ij[triu_idx_all]

    # Exact psi at theta scaled by each alpha, and exact eta
    psi1 = numpy.empty([T, 3])
    psi2 = numpy.empty([T, 3])
    eta1 = numpy.empty(theta1.shape)
    eta2 = numpy.empty(theta2.shape)
    alpha = [.999, 1., 1.001]
    transforms.initialise(N, O)
    for i in range(T):
        for j, a in enumerate(alpha):
            psi1[i, j] = transforms.compute_psi(a * theta1[i])
        p = transforms.compute_p(theta1[i])
        eta1[i] = transforms.compute_eta(p)
        for j, a in enumerate(alpha):
            psi2[i, j] = transforms.compute_psi(a * theta2[i])
        p = transforms.compute_p(theta2[i])
        eta2[i] = transforms.compute_eta(p)
    psi_all = psi1 + psi2
    S1 = -numpy.sum(eta1 * theta1, axis=1) + psi1[:, 1]
    S1 /= numpy.log(2)
    S2 = -numpy.sum(eta2 * theta2, axis=1) + psi2[:, 1]
    S2 /= numpy.log(2)
    S_all = S1 + S2
    C1 = (psi1[:, 0] - 2. * psi1[:, 1] + psi1[:, 2]) / .001 ** 2
    C1 /= numpy.log(2)
    C2 = (psi2[:, 0] - 2. * psi2[:, 1] + psi2[:, 2]) / .001 ** 2
    C2 /= numpy.log(2)
    C_all = C1 + C2

    spikes = synthesis.generate_spikes_gibbs_parallel(
        theta_all, 2 * N, O, R, sample_steps=10, num_proc=4)
    print('Model and Data generated')
    emd = __init__.run(spikes, O, map_function='cg', param_est='pseudo',
                       param_est_eta='bethe_hybrid', lmbda1=100, lmbda2=200)

    out_file = data_path + 'figure1data.h5'
    with h5py.File(out_file, 'w') as f:
        g_data = f.create_group('data')
        g_data.create_dataset('theta_all', data=theta_all)
        g_data.create_dataset('psi_all', data=psi_all)
        g_data.create_dataset('S_all', data=S_all)
        g_data.create_dataset('C_all', data=C_all)
        g_data.create_dataset('spikes', data=spikes)
        g_data.create_dataset('theta1', data=theta1)
        g_data.create_dataset('theta2', data=theta2)
        g_data.create_dataset('psi1', data=psi1)
        g_data.create_dataset('S1', data=S1)
        g_data.create_dataset('C1', data=C1)
        g_data.create_dataset('psi2', data=psi2)
        g_data.create_dataset('S2', data=S2)
        g_data.create_dataset('C2', data=C2)
        g_fit = f.create_group('fit')
        g_fit.create_dataset('theta_s', data=emd.theta_s)
        g_fit.create_dataset('sigma_s', data=emd.sigma_s)
        g_fit.create_dataset('Q', data=emd.Q)
    print('Fit and saved')

    # Read the fit back and close the file before forking worker processes.
    # `[()]` replaces `Dataset.value`, which was removed in h5py 3.
    with h5py.File(out_file, 'r') as f:
        theta = f['fit']['theta_s'][()]
        sigma = f['fit']['sigma_s'][()]

    # 100 samples per bin from N(theta, diag(sigma))
    X = numpy.random.randn(theta.shape[0], theta.shape[1], 100)
    theta_sampled = \
        theta[:, :, numpy.newaxis] + X * numpy.sqrt(sigma)[:, :, numpy.newaxis]
    bins = range(theta.shape[0])
    eta_sampled = numpy.empty([theta.shape[0], theta.shape[1], 100])
    psi_sampled = numpy.empty([theta.shape[0], 100, 3])
    func = partial(get_sampled_eta_psi, theta_sampled=theta_sampled, N=2 * N)
    pool = multiprocessing.Pool(10)
    try:
        results = pool.map(func, bins)
    finally:
        # Without close/join the worker processes are leaked.
        pool.close()
        pool.join()
    for eta, psi, i in results:
        eta_sampled[i] = eta
        psi_sampled[i] = psi
    S_sampled = \
        -(numpy.sum(eta_sampled * theta_sampled, axis=1) - psi_sampled[:, :, 1])
    S_sampled /= numpy.log(2)
    C_sampled = (psi_sampled[:, :, 0] - 2. * psi_sampled[:, :, 1] +
                 psi_sampled[:, :, 2]) / .001 ** 2
    C_sampled /= numpy.log(2)

    with h5py.File(out_file, 'r+') as f:
        g_sampled = f.create_group('sampled_results')
        g_sampled.create_dataset('theta_sampled', data=theta_sampled)
        g_sampled.create_dataset('eta_sampled', data=eta_sampled)
        g_sampled.create_dataset('psi_sampled', data=psi_sampled)
        g_sampled.create_dataset('S_sampled', data=S_sampled)
        g_sampled.create_dataset('C_sampled', data=C_sampled)
    print('Done')
def _draw_axis_lines(ax, bottom=True):
    """
    Draw the left (and optionally the bottom) axis line of `ax' by hand.

    Used on axes whose frame has been switched off. The lines span the view
    interval that is current at call time, so call this only after the
    limits/ticks that should determine the line extent have been set.

    :param ax: Matplotlib axes to draw on.
    :param bool bottom: Whether to also draw the bottom (x-axis) line.
    """
    ymin, ymax = ax.get_yaxis().get_view_interval()
    xmin, xmax = ax.get_xaxis().get_view_interval()
    ax.add_artist(pyplot.Line2D((xmin, xmin), (ymin, ymax),
                                color='black', linewidth=2))
    if bottom:
        ax.add_artist(pyplot.Line2D((xmin, xmax), (ymin, ymin),
                                    color='black', linewidth=3))


def plot_figure1(data_path='../Data/', plot_path='../Plots/'):
    """
    Plot figure 1 from the results stored in `figure1data.h5'.

    Panel A shows spike rasters of the first three runs and the run- and
    neuron-averaged spike probability of the data. Panel B shows, at three
    time points, the network of pairwise interactions whose smoothed
    estimate +/-2.58 standard deviations excludes zero, plus three selected
    parameters (true value vs. estimate band). Panel C shows spike and
    silence probabilities, entropy and heat capacity against the 1%-99%
    quantiles of the sampled estimates.

    :param str data_path: Directory (including trailing separator) that
        contains `figure1data.h5'.
    :param str plot_path: Directory (including trailing separator) to which
        `fig1.eps' is written.
    """
    N, O = 30, 2
    # Dividing by log2(e) is the same as multiplying by ln(2).
    log2_e = numpy.log2(numpy.exp(1))
    time_axis = range(500)

    with h5py.File(data_path + 'figure1data.h5', 'r') as f:
        data, fit, sampled = f['data'], f['fit'], f['sampled_results']

        # Figure A: rasters of runs 0-2, drawn slightly offset as a stack.
        fig = pyplot.figure(figsize=(30, 20))
        for trial, rect in enumerate([[0.07, 0.68, .4, .4],
                                      [.06, 0.65, .4, .4],
                                      [.05, 0.62, .4, .4]]):
            ax = fig.add_axes(rect)
            ax.imshow(-data['spikes'][:, trial, :].T, cmap='gray', aspect=5,
                      interpolation='nearest')
            ax.set_xticks([])
            ax.set_yticks([])
        # Labels go on the front-most (last drawn) raster only.
        ax.set_xlabel('Time [AU]', fontsize=26)
        ax.set_ylabel('Neuron ID', fontsize=26)

        ax = fig.add_axes([.05, 0.5, .4, .2])
        ax.set_frame_on(False)
        ax.plot(numpy.mean(numpy.mean(data['spikes'][:, :, :], axis=1),
                           axis=1),
                linewidth=4, color='k')
        _draw_axis_lines(ax)
        ax.yaxis.set_ticks_position('left')
        ax.xaxis.set_ticks_position('bottom')
        ax.set_yticks([.1, .2, .3])
        ax.set_xticks([50, 150, 300])
        ax.set_ylabel('Data $p_{\\mathrm{spike}}$', fontsize=26)
        ax.set_xlabel('Time [AU]', fontsize=26)
        ax.tick_params(axis='both', which='major', labelsize=20)

        # Figure B (Networks): an edge is drawn when the pairwise parameter's
        # estimate +/-2.58 sd band lies entirely above or below zero.
        theta = fit['theta_s'][()]
        sigma_s = fit['sigma_s'][()]
        half_width = 2.58 * numpy.sqrt(sigma_s[:, N:])
        lower = theta[:, N:] - half_width
        upper = theta[:, N:] + half_width
        # Assumes column k of theta[:, N:] belongs to the k-th cell pair in
        # itertools.combinations(range(N), 2) order.
        all_pairs = numpy.array(list(itertools.combinations(range(N), 2)))
        graph_ax = [fig.add_axes([.52, 0.78, .15, .2]),
                    fig.add_axes([.67, 0.78, .15, .2]),
                    fig.add_axes([.82, 0.78, .15, .2])]
        for gax, t in zip(graph_ax, [50, 150, 300]):
            conn_idx = numpy.where(
                numpy.logical_or(lower[t] > 0, upper[t] < 0))[0]
            graph = nx.Graph()
            graph.add_nodes_from(range(N))
            graph.add_edges_from(all_pairs[conn_idx])
            pos = nx.circular_layout(graph)
            net_nodes = nx.draw_networkx_nodes(
                graph, pos, ax=gax, node_color=theta[t, :N],
                cmap=pyplot.get_cmap('hot'), vmin=-3, vmax=-1.)
            # conn_idx indexes the pairwise block, which starts at column N.
            e1 = nx.draw_networkx_edges(
                graph, pos, ax=gax,
                edge_color=theta[t, N + conn_idx].tolist(),
                edge_cmap=pyplot.get_cmap('seismic'),
                edge_vmin=-.7, edge_vmax=.7, width=2)
            gax.axis('off')
            x0, x1 = gax.get_xlim()
            y0, y1 = gax.get_ylim()
            gax.set_aspect(abs(x1 - x0) / abs(y1 - y0))
            gax.set_title('t=%d' % t, fontsize=24)

        cbar_ax = fig.add_axes([0.62, 0.79, 0.1, 0.01])
        cbar_ax.tick_params(axis='both', which='major', labelsize=20)
        cbar = fig.colorbar(net_nodes, cax=cbar_ax, orientation='horizontal')
        cbar.set_ticks([-3, -2, -1])
        cbar_ax.set_title('$\\theta_{i}$', fontsize=22)
        cbar_ax = fig.add_axes([0.77, 0.79, 0.1, 0.01])
        cbar = fig.colorbar(e1, cax=cbar_ax, orientation='horizontal')
        cbar.set_ticks([-.5, 0., .5])
        cbar_ax.set_title('$\\theta_{ij}$', fontsize=22)
        cbar_ax.tick_params(axis='both', which='major', labelsize=20)

        # Figure B (Thetas): three hand-picked parameter columns, true value
        # (black) vs. estimate +/-2.58 sd (grey band).
        picked = [165, 170, 182]
        theta_true = data['theta_all'][:, picked]
        theta_fit = fit['theta_s'][:, picked]
        band = 2.58 * numpy.sqrt(fit['sigma_s'][:, picked])

        ax1 = fig.add_axes([.55, 0.68, .4, .1])
        ax1.set_frame_on(False)
        ax1.fill_between(time_axis, theta_fit[:, 0] - band[:, 0],
                         theta_fit[:, 0] + band[:, 0], color=[.4, .4, .4])
        ax1.plot(time_axis, theta_true[:, 0], linewidth=4, color='k')
        ax1.set_yticks([-1, 0, 1])
        ax1.set_ylim([-1.1, 1.1])
        _draw_axis_lines(ax1, bottom=False)
        ax1.set_xticks([])
        ax1.yaxis.set_ticks_position('left')
        ax1.xaxis.set_ticks_position('bottom')
        ax1.tick_params(axis='both', which='major', labelsize=20)

        ax1 = fig.add_axes([.55, 0.57, .4, .1])
        ax1.set_frame_on(False)
        ax1.fill_between(time_axis, theta_fit[:, 1] - band[:, 1],
                         theta_fit[:, 1] + band[:, 1], color=[.5, .5, .5])
        ax1.plot(time_axis, theta_true[:, 1], linewidth=4, color='k')
        ax1.set_yticks([-1, 0, 1])
        ax1.set_ylim([-1.1, 1.5])
        _draw_axis_lines(ax1, bottom=False)
        ax1.set_xticks([])
        ax1.yaxis.set_ticks_position('left')
        ax1.xaxis.set_ticks_position('bottom')
        ax1.set_ylabel('$\\theta_{ij}$', fontsize=26)
        ax1.tick_params(axis='both', which='major', labelsize=20)

        ax1 = fig.add_axes([.55, 0.46, .4, .1])
        ax1.set_frame_on(False)
        ax1.fill_between(time_axis, theta_fit[:, 2] - band[:, 2],
                         theta_fit[:, 2] + band[:, 2], color=[.6, .6, .6])
        ax1.plot(time_axis, theta_true[:, 2], linewidth=4, color='k')
        ax1.set_ylim([-1.1, 1.1])
        _draw_axis_lines(ax1)
        ax1.set_xticks([50, 150, 300])
        ax1.yaxis.set_ticks_position('left')
        ax1.xaxis.set_ticks_position('bottom')
        ax1.set_xlabel('Time [AU]', fontsize=26)
        ax1.set_yticks([-1, 0, 1])
        ax1.tick_params(axis='both', which='major', labelsize=20)

        # Figure C
        psi_color = numpy.array([51, 153., 255]) / 256.
        eta_color = numpy.array([0, 204., 102]) / 256.
        S_color = numpy.array([255, 162, 0]) / 256.
        C_color = numpy.array([204, 60, 60]) / 256.
        probs = [.01, .99]
        psi_quantiles = mquantiles(sampled['psi_sampled'][:, :, 1],
                                   prob=probs, axis=1)
        psi_true = data['psi_all'][()]
        eta_quantiles = mquantiles(
            numpy.mean(sampled['eta_sampled'][:, :N, :], axis=1),
            prob=probs, axis=1)
        C_quantiles = mquantiles(sampled['C_sampled'][:, :], prob=probs,
                                 axis=1)
        C_true = data['C_all'][()]
        S_quantiles = mquantiles(sampled['S_sampled'][:, :], prob=probs,
                                 axis=1)
        S_true = data['S_all'][()]

        # True eta of the two 15-cell halves (theta1, theta2). Each
        # dataset is read once instead of row by row from the file.
        theta1 = data['theta1'][()]
        theta2 = data['theta2'][()]
        eta1 = numpy.empty(theta1.shape)
        eta2 = numpy.empty(theta2.shape)
        N1, N2 = 15, 15
        transforms.initialise(N1, O)
        for i in range(theta1.shape[0]):
            eta1[i] = transforms.compute_eta(transforms.compute_p(theta1[i]))
            eta2[i] = transforms.compute_eta(transforms.compute_p(theta2[i]))
        eta_true = .5 * (numpy.mean(eta1[:, :N1], axis=1) +
                         numpy.mean(eta2[:, :N2], axis=1))

        ax1 = fig.add_axes([.08, 0.23, .4, .15])
        ax1.set_frame_on(False)
        ax1.fill_between(time_axis, eta_quantiles[:, 0], eta_quantiles[:, 1],
                         color=eta_color)
        ax1.plot(time_axis, eta_true, linewidth=4, color=eta_color * .8)
        ax1.set_yticks([.1, .2, .3])
        ax1.set_ylim([.09, .35])
        _draw_axis_lines(ax1)
        ax1.set_xticks([50, 150, 300])
        ax1.yaxis.set_ticks_position('left')
        ax1.xaxis.set_ticks_position('bottom')
        ax1.set_ylabel('$p_{\\mathrm{spike}}$', fontsize=26)
        ax1.tick_params(axis='both', which='major', labelsize=20)

        ax1 = fig.add_axes([.08, 0.05, .4, .15])
        ax1.set_frame_on(False)
        ax1.fill_between(time_axis, numpy.exp(-psi_quantiles[:, 0]),
                         numpy.exp(-psi_quantiles[:, 1]), color=psi_color)
        ax1.plot(time_axis, numpy.exp(-psi_true), linewidth=4,
                 color=psi_color * .8)
        ax1.set_yticks([.0, .01, .02])
        ax1.set_ylim([.0, .025])
        _draw_axis_lines(ax1)
        ax1.set_xticks([50, 150, 300])
        ax1.yaxis.set_ticks_position('left')
        ax1.xaxis.set_ticks_position('bottom')
        ax1.set_ylabel('$p_{\\mathrm{silence}}$', fontsize=26)
        ax1.set_xlabel('Time [AU]', fontsize=26)
        ax1.tick_params(axis='both', which='major', labelsize=20)

        # Entropy
        ax2 = fig.add_axes([.52, 0.23, .4, .15])
        ax2.set_frame_on(False)
        ax2.fill_between(time_axis, S_quantiles[:, 0] / log2_e,
                         S_quantiles[:, 1] / log2_e, color=S_color)
        ax2.plot(time_axis, S_true / log2_e, linewidth=4, color=S_color * .8)
        ax2.set_xticks([50, 150, 300])
        ax2.set_yticks([10, 14, 18])
        _draw_axis_lines(ax2)
        ax2.yaxis.set_ticks_position('left')
        ax2.xaxis.set_ticks_position('bottom')
        ax2.set_ylabel('$S$', fontsize=26)
        ax2.tick_params(axis='both', which='major', labelsize=20)

        # Heat capacity
        ax2 = fig.add_axes([.52, 0.05, .4, .15])
        ax2.set_frame_on(False)
        ax2.fill_between(time_axis, C_quantiles[:, 0] / log2_e,
                         C_quantiles[:, 1] / log2_e, color=C_color)
        ax2.plot(time_axis, C_true / log2_e, linewidth=5, color=C_color * .8)
        _draw_axis_lines(ax2)
        ax2.set_xticks([50, 150, 300])
        ax2.set_yticks([5, 10])
        ax2.yaxis.set_ticks_position('left')
        ax2.xaxis.set_ticks_position('bottom')
        ax2.set_xlabel('Time [AU]', fontsize=26)
        ax2.set_ylabel('$C$', fontsize=26)
        ax2.tick_params(axis='both', which='major', labelsize=20)

    # Panel letters.
    for label, rect in [('A', [0.03, 0.95, .05, .05]),
                        ('B', [0.52, 0.95, .05, .05]),
                        ('C', [0.05, 0.4, .05, .05])]:
        ax = fig.add_axes(rect, frameon=0)
        ax.set_yticks([])
        ax.set_xticks([])
        ax.text(.0, .0, label, fontsize=26, fontweight='bold')
    fig.savefig(plot_path + 'fig1.eps')
    pyplot.show()
O = 2 # ----- SPIKE SYNTHESIS ----- # Global module import numpy # Local modules import synthesis import transforms # Create underlying time-varying theta parameters as Gaussian processes # Create mean vector theta = synthesis.generate_thetas(N, O, T) # Initialise the transforms library in preparation for computing P transforms.initialise(N, O) # Compute P for each time step p = numpy.zeros((T, 2**N)) for i in range(T): p[i,:] = transforms.compute_p(theta[i,:]) # Generate spikes! spikes = synthesis.generate_spikes(p, R, seed=1) # ----- ALGORITHM EXECUTION ----- # Global module import numpy # Local module import __init__ # From outside this folder, this would be 'import ssll' # Run the algorithm!
def run(spikes, D, map_function='nr', window=1, lmbda=100, max_iter=1,
        J=None, theta_o=0, sigma_o=0.1, exact=True):
    """
    Master-function of the State-Space Analysis of Spike Correlation package.

    Uses the expectation-maximisation algorithm to find the probability
    distributions of natural parameters of spike-train interactions over
    time. Calls slave functions to perform the expectation and maximisation
    steps repeatedly until the relative change of the log marginal
    likelihood, abs(1 - previous / current), is no longer above
    `exp_max.CONVERGED', or until `max_iter' iterations have been run.

    Side effect: calls numpy.seterr(invalid='raise', under='raise'), which
    changes numpy's global floating-point error handling and is not restored
    when this function returns.

    :param numpy.ndarray spikes: Binary matrix with dimensions (time, runs,
        cells), in which a `1' in location (t, r, c) denotes a spike at time
        t in run r by cell c.
    :param int D: Number of graphs to consider in the fitting of the model.
    :param string map_function: Name of the function to use for maximum
        a-posterior estimation of the natural parameters at each timestep.
        Refer to max_posterior.py.
    :param int window: Bin-width for counting spikes, in milliseconds.
    :param float lmbda: Coefficient on the identity matrix of the initial
        state-transition covariance matrix.
    :param int max_iter: Maximum number of iterations for which to run the
        EM algorithm.
    :param numpy.ndarray J: First estimate of the graphs matrix.
    :param float theta_o: Passed unchanged to container.EMData; presumably
        the initial value of the natural parameters (TODO confirm).
    :param float sigma_o: Passed unchanged to container.EMData; presumably
        the initial (co)variance scale of the natural parameters
        (TODO confirm).
    :param bool exact: Whether to use the exact solution (True) or the Bethe
        approximation (False) for the computation of eta, psi and the Fisher
        information. The coordinate-transform maps are only initialised when
        this is True.

    :returns:
        Results encapsulated in a container.EMData object, containing the
        smoothed posterior probability distributions of the natural
        parameters of the spike-train interactions at each timestep,
        conditional upon the given spikes.
    """
    # Ensure NaNs (invalid operations) and underflows are caught
    numpy.seterr(invalid='raise', under='raise')
    # Order of interactions considered in the graphs (pairwise)
    order_ising = 2
    # Initialise the EM-data container
    map_func = max_posterior.functions[map_function]
    emd = container.EMData(spikes, window, D, map_func, lmbda, J,
                           theta_o, sigma_o, exact)
    # The exact coordinate-transform maps are only needed for the exact path
    if exact:
        transforms.initialise(emd.N, order_ising)
    # Log marginal before any EM step; the previous value is only read after
    # it has been set inside the loop.
    lmc = probability.log_marginal(emd)
    # Iterate the EM algorithm until convergence or the iteration budget
    while (emd.iterations < max_iter) and (emd.convergence > exp_max.CONVERGED):
        exp_max.e_step(emd)
        exp_max.m_step(emd)
        # Shift current -> previous and recompute the log marginal
        lmp, lmc = lmc, probability.log_marginal(emd)
        # Update EM algorithm metadata (relative change of the log marginal)
        emd.iterations += 1
        emd.convergence = abs(1 - lmp / lmc)
    return emd