Beispiel #1
0
def run(spikes, order, window=1, map_function='nr', lmbda=0.005, max_iter=10):
    """
    Entry point of the State-Space Analysis of Spike Correlation package.

    Alternates the expectation and maximisation steps of the EM algorithm to
    estimate the probability distributions of the natural parameters of
    spike-train interactions over time. Iteration stops once `max_iter'
    iterations were performed or once the convergence measure (ratio of the
    previous to the current log marginal likelihood) is no longer greater
    than exp_max.CONVERGED.

    Note that some of the functions called from here are of exponential
    complexity with respect to the `order' parameter.

    Side effect: numpy's floating-point error handling is switched to raise
    on invalid operations and on underflow, and is not restored afterwards.

    :param numpy.ndarray spikes:
        Binary matrix with dimensions (time, runs, cells), in which a `1' in
        location (t, r, c) denotes a spike at time t in run r by cell c.
    :param int order:
        Order of spike-train interactions to estimate, for example, 2 =
        pairwise, 3 = triplet-wise...
    :param int window:
        Bin-width for counting spikes, in milliseconds.
    :param string map_function:
        Key into max_posterior.functions selecting the function used for
        maximum a-posteriori estimation of the natural parameters at each
        timestep. Refer to max_posterior.py.
    :param float lmbda:
        Coefficient on the identity matrix of the initial state-transition
        covariance matrix.
    :param int max_iter:
        Maximum number of iterations for which to run the EM algorithm.

    :returns:
        Results encapsulated in a container.EMData object, containing the
        smoothed posterior probability distributions of the natural parameters
        of the spike-train interactions at each timestep, conditional upon the
        given spikes.
    """
    # Make invalid operations and underflows raise instead of yielding NaNs
    numpy.seterr(invalid='raise', under='raise')
    # Resolve the MAP estimator and build the EM-data container around it
    estimator = max_posterior.functions[map_function]
    emd = container.EMData(spikes, order, window, estimator, lmbda)
    # Prepare the coordinate-transform maps
    transforms.initialise(emd.N, emd.order)
    # Log marginal likelihood before the first EM iteration
    previous = -numpy.inf
    current = probability.log_marginal(emd)
    while emd.iterations < max_iter:
        # Written as `not (a > b)' so that a NaN convergence value also stops
        # the loop, exactly like a failed `a > b' loop condition would
        if not (emd.convergence > exp_max.CONVERGED):
            break
        # One full EM iteration
        exp_max.e_step(emd)
        exp_max.m_step(emd)
        # Shift the log marginal values: current becomes previous
        previous, current = current, probability.log_marginal(emd)
        # Book-keeping on the container
        emd.iterations += 1
        emd.convergence = previous / current

    return emd
Beispiel #2
0
    def __init__(self, spikes, order, window, param_est, param_est_eta, map_function,
                 lmbda1, lmbda2):
        """
        Build the EM-data container for the given spikes.

        :param numpy.ndarray spikes:
            Array whose shape unpacks into (time, runs, cells).
        :param int order:
            Order of the interactions, forwarded to the transforms.
        :param int window:
            Bin-width used when computing the sample interactions `y'.
        :param string param_est:
            'exact' selects max_posterior.functions and full covariance
            matrices per timestep, 'pseudo' selects
            pseudo_likelihood.functions. Any value other than 'exact' uses
            diagonal-only covariances; for values other than these two no
            `max_posterior' attribute is set.
        :param string param_est_eta:
            Key into log_marginal_functions.
        :param string map_function:
            Key into the function table selected by `param_est'.
        :param float lmbda1:
            Q gets 1 / lmbda1 on the diagonal of its first N dimensions.
        :param float lmbda2:
            Q gets 1 / lmbda2 on the diagonal of the remaining D - N
            dimensions.
        """
        # Keep the raw inputs
        self.spikes = spikes
        self.order = order
        self.window = window
        n_bins, self.R, self.N = self.spikes.shape

        # Prepare the chosen likelihood and pick its MAP estimator
        if param_est == 'exact':
            transforms.initialise(self.N, self.order)
            self.max_posterior = max_posterior.functions[map_function]
        elif param_est == 'pseudo':
            pseudo_likelihood.compute_Fx_s(self.spikes, self.order)
            self.max_posterior = pseudo_likelihood.functions[map_function]

        self.param_est_theta = param_est
        self.param_est_eta = param_est_eta
        self.marg_llk = log_marginal_functions[param_est_eta]

        # `Sample' spike-train interactions computed from the input spikes;
        # their shape fixes the number of timesteps and of interaction terms
        self.y = transforms.compute_y(self.spikes, self.order, self.window)
        self.T, self.D = self.y.shape
        assert self.T == n_bins / window

        # Means of the one-step-prediction, filtered and smoothed densities
        mean_shape = (self.T, self.D)
        self.theta_o = numpy.zeros(mean_shape)
        self.theta_f = numpy.zeros(mean_shape)
        self.theta_s = numpy.zeros(mean_shape)

        if param_est == 'exact':
            # Full covariances: one scaled identity matrix per timestep
            eye_stack = numpy.vstack(
                [numpy.identity(self.D) for _ in range(self.T)]
            ).reshape((self.T, self.D, self.D))
            self.sigma_o = .1 * eye_stack
            self.sigma_o_inv = 1. / .1 * eye_stack
            self.sigma_f = numpy.copy(self.sigma_o)
            self.sigma_s = numpy.copy(self.sigma_o)
            self.sigma_s_lag = numpy.copy(self.sigma_o)
        else:
            # Approximate estimation keeps only the covariance diagonals
            self.sigma_o = .1 * numpy.ones(mean_shape)
            self.sigma_o_inv = 1. / .1 * numpy.ones(mean_shape)
            self.sigma_f = .1 * numpy.ones(mean_shape)
            self.sigma_s = .1 * numpy.ones(mean_shape)
            self.sigma_s_lag = .1 * numpy.ones(mean_shape)

        # Autoregressive matrix F and block-diagonal matrix Q with separate
        # scales for the first N and the remaining D - N dimensions
        self.F = numpy.identity(self.D)
        self.Q = numpy.zeros([self.D, self.D])
        n_rest = self.D - self.N
        self.Q[:self.N, :self.N] = 1. / lmbda1 * numpy.identity(self.N)
        self.Q[self.N:, self.N:] = 1. / lmbda2 * numpy.identity(n_rest)
        self.mllk = numpy.inf

        # Metadata about EM algorithm execution
        self.iterations = 0
        self.convergence = numpy.inf
    def __init__(self, spikes, order, window, param_est, param_est_eta, map_function,
                 lmbda1, lmbda2, theta_o, sigma_o):
        """
        Build the EM-data container with a caller-supplied initial density.

        :param numpy.ndarray spikes:
            Array whose shape unpacks into (time, runs, cells).
        :param int order:
            Order of the interactions, forwarded to the transforms.
        :param int window:
            Bin-width used when computing the sample interactions `y'.
        :param string param_est:
            'exact' selects max_posterior.functions and full covariance
            matrices per timestep, 'pseudo' selects
            pseudo_likelihood.functions. Any value other than 'exact' uses
            diagonal-only covariances; for values other than these two no
            `max_posterior' attribute is set.
        :param string param_est_eta:
            Key into log_marginal_functions.
        :param string map_function:
            Key into the function table selected by `param_est'.
        :param float lmbda1:
            Q gets 1 / lmbda1 on the diagonal of its first N dimensions.
        :param float lmbda2:
            Q gets 1 / lmbda2 on the diagonal of the remaining D - N
            dimensions.
        :param theta_o:
            Multiplied onto a (T, D) array of ones to give the initial
            one-step-prediction mean.
        :param sigma_o:
            Scale of the initial one-step-prediction covariance; multiplied
            onto the stacked identity matrices ('exact') or onto a (T, D)
            array of ones (otherwise).
        """
        # Keep the raw inputs
        self.spikes = spikes
        self.order = order
        self.window = window
        n_bins, self.R, self.N = self.spikes.shape

        # Prepare the chosen likelihood and pick its MAP estimator
        if param_est == 'exact':
            transforms.initialise(self.N, self.order)
            self.max_posterior = max_posterior.functions[map_function]
        elif param_est == 'pseudo':
            pseudo_likelihood.compute_Fx_s(self.spikes, self.order)
            self.max_posterior = pseudo_likelihood.functions[map_function]

        self.param_est_theta = param_est
        self.param_est_eta = param_est_eta
        self.marg_llk = log_marginal_functions[param_est_eta]

        # `Sample' spike-train interactions computed from the input spikes;
        # their shape fixes the number of timesteps and of interaction terms
        self.y = transforms.compute_y(self.spikes, self.order, self.window)
        self.T, self.D = self.y.shape
        assert self.T == n_bins / window

        # Means of the one-step-prediction, filtered and smoothed densities
        mean_shape = (self.T, self.D)
        self.theta_o = numpy.ones(mean_shape) * theta_o
        self.theta_f = numpy.zeros(mean_shape)
        self.theta_s = numpy.zeros(mean_shape)

        if param_est == 'exact':
            # Full covariances: one scaled identity matrix per timestep
            eye_stack = numpy.vstack(
                [numpy.identity(self.D) for _ in range(self.T)]
            ).reshape((self.T, self.D, self.D))
            self.sigma_o = sigma_o * eye_stack
            self.sigma_o_inv = numpy.linalg.inv(self.sigma_o)
            self.sigma_f = numpy.copy(self.sigma_o)
            self.sigma_s = numpy.copy(self.sigma_o)
            self.sigma_s_lag = numpy.copy(self.sigma_o)
        else:
            # Approximate estimation keeps only the covariance diagonals
            self.sigma_o = sigma_o * numpy.ones(mean_shape)
            self.sigma_o_inv = 1. / sigma_o * numpy.ones(mean_shape)
            # NOTE(review): unlike the 'exact' branch these three start from
            # the fixed value .1 rather than from sigma_o -- confirm intended
            self.sigma_f = .1 * numpy.ones(mean_shape)
            self.sigma_s = .1 * numpy.ones(mean_shape)
            self.sigma_s_lag = .1 * numpy.ones(mean_shape)

        # Autoregressive matrix F and block-diagonal matrix Q with separate
        # scales for the first N and the remaining D - N dimensions
        self.F = numpy.identity(self.D)
        self.Q = numpy.zeros([self.D, self.D])
        n_rest = self.D - self.N
        self.Q[:self.N, :self.N] = 1. / lmbda1 * numpy.identity(self.N)
        self.Q[self.N:, self.N:] = 1. / lmbda2 * numpy.identity(n_rest)
        self.mllk = numpy.inf

        # Metadata about EM algorithm execution
        self.iterations = 0
        self.convergence = numpy.inf
def compute_eta(theta, N, O, R=1000):
    """ Computes eta from given theta.

    :param numpy.ndarray theta:
        (t, d) array with natural parameters
    :param int N:
        number of cells
    :param int O:
        order of model, only 1 and 2 are supported
    :param int R:
        trials that should be sampled to estimate eta
    :return numpy.ndarray, list:
        array with the expectation parameters eta for every bin ((t, d) for second order; for first order whatever
        compute_ind_eta returns for the first N columns) and a list with indices of bins, for which has been sampled
    :raises ValueError:
        if O is neither 1 nor 2

    Details: For first order the analytical solution is used. For second order and networks with at most 15 neurons
    the exact solution is computed. For larger networks eta is estimated by solving the forward problem from TAP;
    bins for which that fails are estimated from sampled spikes instead.
    """
    if O not in (1, 2):
        # Without this guard an unsupported order fell through every branch
        # and the uninitialised numpy.empty buffer was returned as eta.
        raise ValueError('compute_eta supports only O == 1 or O == 2, got %r'
                         % (O,))
    T, D = theta.shape
    bins_to_sample = []
    if O == 1:
        return compute_ind_eta(theta[:, :N]), bins_to_sample

    eta = numpy.empty(theta.shape)
    # if large ensemble approximate
    if N > 15:
        for i in range(T):
            # try to solve forward problem
            try:
                eta[i] = mean_field.forward_problem(theta[i], N, 'TAP')
            # if it fails (for whatever reason) remember bin for sampling
            except Exception:
                bins_to_sample.append(i)
        if bins_to_sample:
            # sample spikes for all failed bins at once and use their
            # empirical interaction rates as eta
            theta_to_sample = numpy.empty([len(bins_to_sample), D])
            for idx, bin2sampl in enumerate(bins_to_sample):
                theta_to_sample[idx] = theta[bin2sampl]
            spikes = synthesis.generate_spikes_gibbs_parallel(
                theta_to_sample, N, O, R, sample_steps=100)
            eta_from_sample = transforms.compute_y(spikes, O, 1)
            for idx, bin2sampl in enumerate(bins_to_sample):
                eta[bin2sampl] = eta_from_sample[idx]
    # if few cells compute exact rates
    else:
        transforms.initialise(N, O)
        for i in range(T):
            p = transforms.compute_p(theta[i])
            eta[i] = transforms.compute_eta(p)

    return eta, bins_to_sample
Beispiel #5
0
def compute_eta(theta, N, O, R=1000):
    """ Computes eta from given theta.

    :param numpy.ndarray theta:
        (t, d) array with natural parameters
    :param int N:
        number of cells
    :param int O:
        order of model, only 1 and 2 are supported
    :param int R:
        trials that should be sampled to estimate eta
    :return numpy.ndarray, list:
        array with the expectation parameters eta for every bin ((t, d) for second order; for first order whatever
        compute_ind_eta returns for the first N columns) and a list with indices of bins, for which has been sampled
    :raises ValueError:
        if O is neither 1 nor 2

    Details: For first order the analytical solution is used. For second order and networks with at most 15 neurons
    the exact solution is computed. For larger networks eta is estimated by solving the forward problem from TAP;
    bins for which that fails are estimated from sampled spikes instead.
    """
    if O not in (1, 2):
        # Without this guard an unsupported order fell through every branch
        # and the uninitialised numpy.empty buffer was returned as eta.
        raise ValueError('compute_eta supports only O == 1 or O == 2, got %r'
                         % (O,))
    T, D = theta.shape
    bins_to_sample = []
    if O == 1:
        return compute_ind_eta(theta[:, :N]), bins_to_sample

    eta = numpy.empty(theta.shape)
    # if large ensemble approximate
    if N > 15:
        for i in range(T):
            # try to solve forward problem
            try:
                eta[i] = mean_field.forward_problem(theta[i], N, 'TAP')
            # if it fails (for whatever reason) remember bin for sampling
            except Exception:
                bins_to_sample.append(i)
        if bins_to_sample:
            # sample spikes for all failed bins at once and use their
            # empirical interaction rates as eta
            theta_to_sample = numpy.empty([len(bins_to_sample), D])
            for idx, bin2sampl in enumerate(bins_to_sample):
                theta_to_sample[idx] = theta[bin2sampl]
            spikes = synthesis.generate_spikes_gibbs_parallel(
                theta_to_sample, N, O, R, sample_steps=100)
            eta_from_sample = transforms.compute_y(spikes, O, 1)
            for idx, bin2sampl in enumerate(bins_to_sample):
                eta[bin2sampl] = eta_from_sample[idx]
    # if few cells compute exact rates
    else:
        transforms.initialise(N, O)
        for i in range(T):
            p = transforms.compute_p(theta[i])
            eta[i] = transforms.compute_eta(p)

    return eta, bins_to_sample
Beispiel #6
0
 def run_ssasc(self, theta, N, O):
     """
     Generate spikes from known natural parameters, fit them and check the
     quality of the fit.

     Pattern probabilities are computed from `theta' for each of the self.T
     time bins, spikes are drawn for self.R runs with seed self.spike_seed,
     and the fit obtained from __init__.run is compared with `theta' via
     klic. The check fails if any of those values, ignoring the first and
     last 50 entries, exceeds .01; in that case the result is plotted before
     the assertion is made.

     :param theta: Natural parameters, indexed by time bin along axis 0.
     :param N: Number of cells; 2**N patterns are considered per bin.
     :param O: Order of the interactions.
     """
     def exceeds_tolerance():
         # Evaluated afresh on every call, on the interior of the series only
         return numpy.any(kld[50:-50] > .01)

     # Set up the library for computing pattern probabilities
     transforms.initialise(N, O)
     # One row of pattern probabilities per time bin
     pattern_probs = numpy.zeros((self.T, 2**N))
     for t in xrange(self.T):
         pattern_probs[t,:] = transforms.compute_p(theta[t,:])
     # Draw spikes according to those probabilities
     spikes = synthesis.generate_spikes(pattern_probs, self.R,
                                        seed=self.spike_seed)
     # Fit the model to the generated spikes
     emd = __init__.run(spikes, O)
     # KL divergence between real and estimated parameters
     kld = klic(theta, emd.theta_s, emd.N)
     # Plot first when the check is about to fail, then assert
     if exceeds_tolerance():
         self.plot(theta, emd.theta_s, emd.sigma_s, emd.y, kld, emd.N, emd.T,
             emd.D)
     self.assertFalse(exceeds_tolerance())
Beispiel #7
0
def compute_psi(theta, N, O, R=1000):
    """ Computes psi from given theta.

    :param numpy.ndarray theta:
        (t, d) array with natural parameters
    :param int N:
        number of cells
    :param int O:
        order of model, only 1 and 2 are supported
    :param int R:
        not used by this function; kept so that existing callers keep working
    :return numpy.ndarray, list:
        array with the log-partition for every bin (length t for second order; for first order whatever
        compute_ind_psi returns) and a list with indices of bins, for which has been sampled
    :raises ValueError:
        if O is neither 1 nor 2

    For first order the analytical solution is used. For networks with 15 units and less the exact solution is computed.
    Otherwise, the Ogata-Tanemura-Estimator is used. It tries to solve the forward problem and samples where it fails.
    """
    if O not in (1, 2):
        # Without this guard an unsupported order skipped both branches and
        # the uninitialised numpy.empty buffer was returned as psi.
        raise ValueError('compute_psi supports only O == 1 or O == 2, got %r'
                         % (O,))
    T = theta.shape[0]
    bins_sampled = []
    if O == 1:
        return compute_ind_psi(theta[:,:N]), bins_sampled

    psi = numpy.empty(T)
    # if large ensemble approximate
    if N > 15:
        # reference model for the estimator: same first-order parameters,
        # all higher-order parameters set to zero
        theta0 = numpy.copy(theta)
        theta0[:,N:] = 0
        psi0 = compute_ind_psi(theta0[:,:N])
        for i in range(T):
            # NOTE(review): the meaning of the trailing N argument is defined
            # by ot_estimator, which is not visible here
            psi[i], sampled = ot_estimator(theta0[i], psi0[i], theta[i], N, O, N)
            # save bin if sampled
            if sampled:
                bins_sampled.append(i)
    # if few cells compute exact result
    else:
        transforms.initialise(N, 2)
        for i in range(T):
            psi[i] = transforms.compute_psi(theta[i])
    return psi, bins_sampled
def compute_psi(theta, N, O, R=1000):
    """ Computes psi from given theta.

    :param numpy.ndarray theta:
        (t, d) array with natural parameters
    :param int N:
        number of cells
    :param int O:
        order of model, only 1 and 2 are supported
    :param int R:
        not used by this function; kept so that existing callers keep working
    :return numpy.ndarray:
        array with the log-partition for every bin (length t for second order; for first order whatever
        compute_ind_psi returns)
    :raises ValueError:
        if O is neither 1 nor 2

    For first order the analytical solution is used. For networks with 15 units and less the exact solution is computed.
    Otherwise, the Ogata-Tanemura-Estimator is used.
    """
    if O not in (1, 2):
        # Without this guard an unsupported order skipped both branches and
        # the uninitialised numpy.empty buffer was returned as psi.
        raise ValueError('compute_psi supports only O == 1 or O == 2, got %r'
                         % (O,))
    T = theta.shape[0]
    if O == 1:
        return compute_ind_psi(theta[:, :N])

    psi = numpy.empty(T)
    # if large ensemble approximate
    if N > 15:
        # reference model for the estimator: same first-order parameters,
        # all higher-order parameters set to zero
        theta0 = numpy.copy(theta)
        theta0[:, N:] = 0
        psi0 = compute_ind_psi(theta0[:, :N])
        for i in range(T):
            # NOTE(review): the meaning of the trailing N argument is defined
            # by ot_estimator, which is not visible here
            psi[i] = ot_estimator(theta0[i], psi0[i], theta[i], N, O, N)
    # if few cells compute exact result
    else:
        transforms.initialise(N, 2)
        for i in range(T):
            psi[i] = transforms.compute_psi(theta[i])
    return psi
Beispiel #9
0
# Script section: model sizes, precision, random seed and the generative
# J matrix. `exact' and `N' are defined outside this section.

# Number of dimensions of the generative model
D = 2

# Number of graphs in the fitted model
D_fit = 2

# Precision
lmbda = 1000

# Random seed
# NOTE(review): numpy.random.seed returns None, so `seed' is bound to None;
# the call matters only for its side effect of seeding numpy's global
# generator. Check whether `seed' is used as a value further down.
seed = numpy.random.seed(0)

# p_map and e_map (Initialise)
if exact:
    transforms.initialise(N, 2)

# Number of dimensions of the graphs
dim = transforms.compute_D(N, 2)

# Generative J matrix
# One column per graph. The first N rows of a column hold the diagonal of the
# outer product of a +/-1 vector with itself, the remaining rows hold the
# entries above the diagonal of that outer product.
# NOTE(review): x1 and x2 have 9 entries, so the assignments below presumably
# require N == 9 -- confirm against where N is set.
# NOTE(review): J_gen is sized with D_fit rather than D; both are 2 here.
J_gen = numpy.zeros((dim, D_fit))
x1 = numpy.array([-1, 1, -1, 1, 1, 1, -1, 1, -1])
x2 = numpy.array([1, 1, 1, -1, 1, -1, -1, 1, -1])
a1 = numpy.outer(x1, x1)
a2 = numpy.outer(x2, x2)
J_gen[:N, 0] = numpy.diag(a1)
J_gen[:N, 1] = numpy.diag(a2)
J_gen[N:, 0] = a1[numpy.triu_indices(N, k=1)]
J_gen[N:, 1] = a2[numpy.triu_indices(N, k=1)]
# Scale every column by the square root of its variance along axis 0
J_gen = J_gen / (numpy.var(J_gen, axis=0)**0.5)
Beispiel #10
0
def generate_data_ctime(data_path='../Data/', max_network_size=60,
                              num_procs=4):
    """
    Measures how long fitting takes for networks of growing size and stores
    the timings in `comp_time_data.h5' below `data_path'.

    Subnetworks of 10 cells each are generated and concatenated without
    couplings between subnetworks, giving networks of 10, 20, ... cells. For
    each network spikes are sampled and the model is fitted twice with
    pseudo-likelihood estimation, once with param_est_eta='bethe_hybrid' and
    once with param_est_eta='mf'.

    The output file holds the dataset `N' (network sizes) and the dataset
    `ctime' of shape (2, number of networks): row 0 contains the wall-clock
    seconds of the 'bethe_hybrid' fit, row 1 those of the 'mf' fit, or NaN
    if that fit raised an exception.

    :param str data_path:
        Prefix of the output file name; joined by plain string concatenation,
        so it has to end with a path separator.
    :param int max_network_size:
        Largest network size; the number of networks is
        max_network_size // 10.
    :param int num_procs:
        Number of processes handed to the spike sampler.
    """
    # Subnetwork size, interaction order, trials and time bins
    N, O, R, T = 10, 2, 500, 500
    # Floor division so that this stays an integer under true division too
    num_of_networks = max_network_size // N
    out_file = data_path + 'comp_time_data.h5'

    # Time-dependent offset for the first-order parameters: zero during the
    # first 100 bins, followed by a smooth function of the bin index
    mu = numpy.zeros(T)
    x = numpy.arange(1, 401)
    mu[100:] = 1. * (3. / (2. * numpy.pi * (x / 400. * 3.) ** 3)) ** .5 * \
               numpy.exp(-3. * ((x / 400. * 3.) - 1.) ** 2 /
                         (2. * (x / 400. * 3.)))

    # Natural parameters of every subnetwork
    D = transforms.compute_D(N, O)
    thetas = numpy.empty([num_of_networks, T, D])
    transforms.initialise(N, O)
    for i in range(num_of_networks):
        thetas[i] = synthesis.generate_thetas(N, O, T, mu1=-2.)
        thetas[i, :, :N] += mu[:, numpy.newaxis]

    # Create the output file. The sizes are derived from num_of_networks so
    # that `N' always has as many entries as `ctime' has columns.
    with h5py.File(out_file, 'w') as f:
        f.create_dataset('N', data=N * numpy.arange(1, num_of_networks + 1))
        f.create_dataset('ctime', shape=[2, num_of_networks])

    # Upper-triangle indices of one subnetwork, identical for all of them
    triu_idx = numpy.triu_indices(N, k=1)
    for i in range(num_of_networks):
        n_cells = (i + 1) * N
        print('N=%d' % n_cells)
        D = transforms.compute_D(n_cells, O)
        theta_all = numpy.empty([T, D])
        triu_idx_all = numpy.triu_indices(n_cells, k=1)

        # First-order parameters: subnetworks side by side
        for j in range(i + 1):
            theta_all[:, N * j:(j + 1) * N] = thetas[j, :, :N]

        # Pairwise parameters: block-diagonal coupling matrix per time bin,
        # flattened along its upper triangle
        for t in range(T):
            theta_ij = numpy.zeros([n_cells, n_cells])
            for j in range(i + 1):
                theta_ij[triu_idx[0] + j * N, triu_idx[1] + j * N] = \
                    thetas[j, t, N:]

            theta_all[t, n_cells:] = theta_ij[triu_idx_all]

        spikes = synthesis.generate_spikes_gibbs_parallel(theta_all,
                                                          n_cells, O, R,
                                                          sample_steps=10,
                                                          num_proc=num_procs)

        # Time the fit with the 'bethe_hybrid' approximation
        t1 = time.time()
        __init__.run(spikes, O, map_function='cg', param_est='pseudo',
                     param_est_eta='bethe_hybrid', lmbda1=100, lmbda2=200)
        ctime_bethe = time.time() - t1

        with h5py.File(out_file, 'r+') as f:
            f['ctime'][0, i] = ctime_bethe

        # Time the fit with the 'mf' approximation. A failing fit is
        # deliberately recorded as NaN so that the remaining network sizes
        # are still measured.
        try:
            t1 = time.time()
            __init__.run(spikes, O, map_function='cg', param_est='pseudo',
                         param_est_eta='mf', lmbda1=100, lmbda2=200)
            ctime_TAP = time.time() - t1
        except Exception:
            ctime_TAP = numpy.nan

        with h5py.File(out_file, 'r+') as f:
            f['ctime'][1, i] = ctime_TAP
Beispiel #11
0
def generate_data_figure3and4(data_path = '../Data/', num_of_iterations=10):
    """
    Fits repeatedly sampled spike data with three estimation methods and
    stores the results in `figure2and3data.h5' below `data_path'.

    The true natural parameters are read from dataset data/theta1 of
    `figure1data.h5'. In every iteration new spikes are drawn from the exact
    pattern probabilities and fitted with the methods 'exact',
    'bethe_hybrid' and 'mf'. For each method the output file holds a group
    with the datasets `MISE_theta', `MISE_psi' (one value per iteration),
    `psi' (iterations x time bins) and, taken from the first iteration only,
    `theta' and `sigma'. The true values are stored as `psi_true' and
    `theta_true' at top level.

    :param str data_path:
        Prefix of the input and output file names; joined by plain string
        concatenation, so it has to end with a path separator.
    :param int num_of_iterations:
        Number of times spikes are sampled and fitted.
    """
    R, T, N, O = 200, 500, 15, 2
    out_file = data_path + 'figure2and3data.h5'

    # `[()]' reads the whole dataset; it replaces the `.value' attribute,
    # which newer h5py releases no longer provide
    with h5py.File(data_path + 'figure1data.h5', 'r') as f:
        theta = f['data']['theta1'][()]

    # Exact log-partition and pattern probabilities of the true model
    transforms.initialise(N, O)
    psi_true = numpy.empty(T)
    for i in range(T):
        psi_true[i] = transforms.compute_psi(theta[i])
    p = numpy.zeros((T, 2 ** N))
    for i in range(T):
        p[i, :] = transforms.compute_p(theta[i, :])
    fitting_methods = ['exact', 'bethe_hybrid', 'mf']

    # Lay out the output file
    with h5py.File(out_file, 'w') as f:
        f.create_dataset('psi_true', data=psi_true)
        f.create_dataset('theta_true', data=theta)
        for fit in fitting_methods:
            g = f.create_group(fit)
            g.create_dataset('MISE_theta', shape=[num_of_iterations])
            g.create_dataset('MISE_psi', shape=[num_of_iterations])
            g.create_dataset('psi', shape=[num_of_iterations, T])

    for iteration in range(num_of_iterations):
        print('Iteration %d' % iteration)
        spikes = synthesis.generate_spikes(p, R, seed=None)

        for fit in fitting_methods:
            # The exact method uses exact estimation throughout; the other
            # two combine pseudo-likelihood with the named eta approximation
            if fit == 'exact':
                emd = __init__.run(spikes, O, map_function='cg',
                                   param_est='exact', param_est_eta='exact')
            else:
                emd = __init__.run(spikes, O, map_function='cg',
                                   param_est='pseudo', param_est_eta=fit)

            # Log-partition of the fitted model, computed with the method
            # that matches the fit
            psi = numpy.empty(T)
            if fit == 'exact':
                for i in range(T):
                    psi[i] = transforms.compute_psi(emd.theta_s[i])
            elif fit == 'bethe_hybrid':
                for i in range(T):
                    psi[i] = bethe_approximation.compute_eta_hybrid(
                        emd.theta_s[i], N, return_psi=1)[1]
            elif fit == 'mf':
                for i in range(T):
                    eta_mf = mean_field.forward_problem(emd.theta_s[i], N,
                                                        'TAP')
                    psi[i] = mean_field.compute_psi(emd.theta_s[i], eta_mf, N)

            # Mean squared errors with respect to the true model
            mise_theta = numpy.mean((theta - emd.theta_s) ** 2)
            mise_psi = numpy.mean((psi_true - psi) ** 2)
            with h5py.File(out_file, 'r+') as f:
                g = f[fit]
                g['MISE_theta'][iteration] = mise_theta
                g['MISE_psi'][iteration] = mise_psi
                # Fitted parameters are kept for the first iteration only
                if iteration == 0:
                    g.create_dataset('theta', data=emd.theta_s)
                    g.create_dataset('sigma', data=emd.sigma_s)
                g['psi'][iteration] = psi
            print('Fitted with %s' % fit)
Beispiel #12
0
def generate_data_figure2(data_path='../Data/', max_network_size=60):
    """
    Generates the data for figure 2 and stores it in `figure2data.h5' below
    `data_path'.

    Subnetworks of 10 cells each are generated together with their exact
    expectation parameters, log-partition, S and C values (the latter two
    divided by log(2)). Networks of 10, 20, ... cells are formed by
    concatenating subnetworks without couplings between them. For each
    network spikes are sampled and fitted with pseudo-likelihood estimation
    and the 'bethe_hybrid' approximation, once with 200 trials (results in
    group `error') and once with 500 trials (results in group `error500').
    The true values of a concatenated network are the sums of the values of
    its subnetworks; the true population rate is their mean.

    :param str data_path:
        Prefix of the output file name; joined by plain string concatenation,
        so it has to end with a path separator.
    :param int max_network_size:
        Largest network size; the number of networks is
        max_network_size // 10.
    """
    # Subnetwork size, interaction order and number of time bins
    N, O, T = 10, 2, 500
    # Floor division so that this stays an integer under true division too
    num_of_networks = max_network_size // N
    out_file = data_path + 'figure2data.h5'
    # Upper-triangle indices of one subnetwork, identical for all of them
    triu_idx = numpy.triu_indices(N, k=1)

    def create_error_group(h5file, name):
        # Creates the group `name' with all result datasets of one pass
        group = h5file.create_group(name)
        for key in ('MISE_thetas', 'MISE_population_rate', 'MISE_psi',
                    'MISE_S', 'MISE_C'):
            group.create_dataset(key, shape=[num_of_networks])
        for key in ('population_rate', 'psi', 'S', 'C'):
            group.create_dataset(key, shape=[num_of_networks, T])

    def joint_theta(thetas, n_blocks):
        # Natural parameters of the first n_blocks subnetworks concatenated
        # into one network without couplings between subnetworks
        n_cells = n_blocks * N
        theta_all = numpy.empty([T, transforms.compute_D(n_cells, O)])
        triu_idx_all = numpy.triu_indices(n_cells, k=1)
        # First-order parameters: subnetworks side by side
        for j in range(n_blocks):
            theta_all[:, N * j:(j + 1) * N] = thetas[j, :, :N]
        # Pairwise parameters: block-diagonal coupling matrix for EVERY time
        # bin, flattened along its upper triangle. (Writing only the last
        # bin would leave the rest of the numpy.empty buffer uninitialised.)
        for t in range(T):
            theta_ij = numpy.zeros([n_cells, n_cells])
            for j in range(n_blocks):
                theta_ij[triu_idx[0] + j * N, triu_idx[1] + j * N] = \
                    thetas[j, t, N:]
            theta_all[t, n_cells:] = theta_ij[triu_idx_all]
        return theta_all

    def fit_and_store(group_name, n_trials, thetas, etas, psi, S, C):
        # Samples and fits every network size with n_trials trials and
        # writes estimates and errors into group `group_name'
        for i in range(num_of_networks):
            n_cells = (i + 1) * N
            print('N=%d' % n_cells)
            theta_all = joint_theta(thetas, i + 1)

            spikes = synthesis.generate_spikes_gibbs_parallel(theta_all,
                                                              n_cells, O,
                                                              n_trials,
                                                              sample_steps=10,
                                                              num_proc=4)
            emd = __init__.run(spikes, O, map_function='cg',
                               param_est='pseudo',
                               param_est_eta='bethe_hybrid', lmbda1=100,
                               lmbda2=200)

            eta_est = numpy.empty(emd.theta_s.shape)
            psi_est = numpy.empty(T)
            S_est = numpy.empty(T)
            C_est = numpy.empty(T)
            for t in range(T):
                eta_est[t], psi_est[t] = \
                    bethe_approximation.compute_eta_hybrid(
                        emd.theta_s[t], n_cells, return_psi=1)
                # psi at slightly scaled parameters for the second-order
                # finite difference below
                psi1 = bethe_approximation.compute_eta_hybrid(
                    .999 * emd.theta_s[t], n_cells, return_psi=1)[1]
                psi2 = bethe_approximation.compute_eta_hybrid(
                    1.001 * emd.theta_s[t], n_cells, return_psi=1)[1]
                S_est[t] = -(numpy.sum(eta_est[t] * emd.theta_s[t]) -
                             psi_est[t])
                C_est[t] = (psi1 - 2. * psi_est[t] + psi2) / .001 ** 2
            S_est /= numpy.log(2)
            C_est /= numpy.log(2)

            # True values of the concatenated network
            population_rate = numpy.mean(
                numpy.mean(etas[:i + 1, :, :N], axis=0), axis=1)
            population_rate_est = numpy.mean(eta_est[:, :n_cells], axis=1)
            psi_true = numpy.sum(psi[:(i + 1), :], axis=0)
            S_true = numpy.sum(S[:(i + 1), :], axis=0)
            C_true = numpy.sum(C[:(i + 1), :], axis=0)

            with h5py.File(out_file, 'r+') as f:
                err = f[group_name]
                err['MISE_thetas'][i] = numpy.mean(
                    (theta_all - emd.theta_s) ** 2)
                err['MISE_population_rate'][i] = numpy.mean(
                    (population_rate - population_rate_est) ** 2)
                err['MISE_psi'][i] = numpy.mean((psi_est - psi_true) ** 2)
                err['MISE_S'][i] = numpy.mean((S_est - S_true) ** 2)
                err['MISE_C'][i] = numpy.mean((C_est - C_true) ** 2)
                err['population_rate'][i] = population_rate_est
                err['psi'][i] = psi_est
                err['S'][i] = S_est
                err['C'][i] = C_est

    # Time-dependent offset for the first-order parameters: zero during the
    # first 100 bins, followed by a smooth function of the bin index
    mu = numpy.zeros(T)
    x = numpy.arange(1, 401)
    mu[100:] = 1. * (3. / (2. * numpy.pi * (x / 400. * 3.) ** 3)) ** .5 * \
               numpy.exp(-3. * ((x / 400. * 3.) - 1.) ** 2 /
                         (2. * (x / 400. * 3.)))

    # Subnetworks and their exact quantities
    D = transforms.compute_D(N, O)
    thetas = numpy.empty([num_of_networks, T, D])
    etas = numpy.empty([num_of_networks, T, D])
    psi = numpy.empty([num_of_networks, T])
    S = numpy.empty([num_of_networks, T])
    C = numpy.empty([num_of_networks, T])
    transforms.initialise(N, O)
    for i in range(num_of_networks):
        thetas[i] = synthesis.generate_thetas(N, O, T, mu1=-2.)
        thetas[i, :, :N] += mu[:, numpy.newaxis]
        for t in range(T):
            p = transforms.compute_p(thetas[i, t])
            etas[i, t] = transforms.compute_eta(p)
            psi[i, t] = transforms.compute_psi(thetas[i, t])
            # psi at slightly scaled parameters for the second-order finite
            # difference
            psi1 = transforms.compute_psi(.999 * thetas[i, t])
            psi2 = transforms.compute_psi(1.001 * thetas[i, t])
            C[i, t] = (psi1 - 2. * psi[i, t] + psi2) / .001 ** 2
            S[i, t] = -(numpy.sum(etas[i, t] * thetas[i, t]) - psi[i, t])
    C /= numpy.log(2)
    S /= numpy.log(2)

    # Store the true values and prepare the group of the first pass
    with h5py.File(out_file, 'w') as f:
        g1 = f.create_group('data')
        g1.create_dataset('thetas', data=thetas)
        g1.create_dataset('etas', data=etas)
        g1.create_dataset('psi', data=psi)
        g1.create_dataset('S', data=S)
        g1.create_dataset('C', data=C)
        create_error_group(f, 'error')

    # First pass: 200 trials
    fit_and_store('error', 200, thetas, etas, psi, S, C)

    # Read the true values back and prepare the group of the second pass.
    # `[()]' reads the whole dataset; it replaces the `.value' attribute,
    # which newer h5py releases no longer provide.
    with h5py.File(out_file, 'r+') as f:
        thetas = f['data']['thetas'][()]
        etas = f['data']['etas'][()]
        psi = f['data']['psi'][()]
        S = f['data']['S'][()]
        C = f['data']['C'][()]
        create_error_group(f, 'error500')

    # Second pass: 500 trials
    fit_and_store('error500', 500, thetas, etas, psi, S, C)
Beispiel #13
0
# Script section: allocate arrays with 12 slots along the first axis and put
# the results held by `emd' into slot 0. `emd' is created before this
# section.
# NOTE(review): the loop after this section visits the orientations
# 30, 60, ..., 330, so the 12 slots presumably correspond to one fit per
# orientation with slot 0 being orientation 0 -- confirm.

# Number of time bins, trials and neurons
T, R, N = emd.spikes.shape

# Number of dimensions
D = emd.D

# Order of interactions considered in the Ising models
O = emd.order

# Initialise the theta matrix and the sigma matrix
theta_m = numpy.zeros((12, T, D))
theta_m[0, :, :] = emd.theta_s
# Only one value per dimension is kept for sigma (no full matrices)
sigma_m = numpy.zeros((12, T, D))
if emd.marg_llk == probability.log_marginal:  # i.e. if exact model is used
    transforms.initialise(N, O)
    for t in range(T):
        sigma_m[0, t] = numpy.diag(
            emd.sigma_s[t]
        )  # Because emd.sigma_s is ((T,D,D)) if exact, only need the diagonal for fig1
else:
    # Otherwise emd.sigma_s is copied as it is
    sigma_m[0] = emd.sigma_s

# Initialize the spike train array
spikes_m = numpy.zeros((12, T, R, N))
spikes_m[0, :, :, :] = emd.spikes

# Iteration over all orientation
for orientation in range(30, 360, 30):
    f = open(
        directory + '/Data/m' + str(monkey + 1) + 'd' + str(orientation) +
Beispiel #14
0
def generate_data_figure1(data_path = '../Data/'):
    """
    Generate, fit and post-process the synthetic data set used for figure 1.

    Two pairwise (order 2) models of 15 neurons each are created with
    `synthesis.generate_thetas`. From time bin 100 onwards a transient `mu`
    is added to the first-order parameters of both models. The two models are
    merged into one 30-neuron parameter matrix in which all couplings between
    the two sub-populations are zero. Spikes are then drawn from the merged
    model, the state-space model is fitted to them, and 100 parameter samples
    per time bin are drawn around the fitted `theta_s` to obtain sampled
    values of eta, psi, entropy `S` and heat capacity `C`.

    Everything is written to `data_path + 'figure1data.h5'` in the groups
    'data', 'fit' and 'sampled_results'.

    :param string data_path:
        Directory prefix (including the trailing separator) of the HDF5
        output file. The file is overwritten if it exists.

    :returns:
        None. All results are stored in the HDF5 file.
    """
    N, O, R, T = 15, 2, 200, 500
    # Transient added to the first-order parameters from time bin 100 onwards
    mu = numpy.zeros(T)
    x = numpy.arange(1, 401)
    mu[100:] = 1. * (3. / (2. * numpy.pi * (x/400.*3.) ** 3)) ** .5 * \
               numpy.exp(-3. * ((x/400.*3.) - 1.) ** 2 / (2. * (x/400.*3.)))
    theta1 = synthesis.generate_thetas(N, O, T, mu1=-2.)
    theta2 = synthesis.generate_thetas(N, O, T, mu1=-2.)
    theta1[:, :N] += mu[:, numpy.newaxis]
    theta2[:, :N] += mu[:, numpy.newaxis]
    # Merge both sub-populations into one model of 2 * N neurons
    D = transforms.compute_D(N * 2, O)
    theta_all = numpy.empty([T, D])
    theta_all[:, :N] = theta1[:, :N]
    theta_all[:, N:2 * N] = theta2[:, :N]
    triu_idx = numpy.triu_indices(N, k=1)
    triu_idx_all = numpy.triu_indices(2 * N, k=1)
    for t in range(T):
        # Couplings are placed block-wise; entries linking the two
        # sub-populations remain zero
        theta_ij = numpy.zeros([2 * N, 2 * N])
        theta_ij[triu_idx] = theta1[t, N:]
        theta_ij[triu_idx[0] + N, triu_idx[1] + N] = theta2[t, N:]
        theta_all[t, 2 * N:] = theta_ij[triu_idx_all]

    psi1 = numpy.empty([T, 3])
    psi2 = numpy.empty([T, 3])
    eta1 = numpy.empty(theta1.shape)
    eta2 = numpy.empty(theta2.shape)
    # psi is evaluated at three scalings of theta so that its second
    # derivative can be approximated by a central finite difference below
    alpha = [.999,1.,1.001]
    transforms.initialise(N, O)
    for i in range(T):
        for j, a in enumerate(alpha):
            psi1[i, j] = transforms.compute_psi(a * theta1[i])
        p = transforms.compute_p(theta1[i])
        eta1[i] = transforms.compute_eta(p)
        for j, a in enumerate(alpha):
            psi2[i, j] = transforms.compute_psi(a * theta2[i])
        p = transforms.compute_p(theta2[i])
        eta2[i] = transforms.compute_eta(p)

    # Quantities of the merged model are the sums over both sub-populations
    psi_all = psi1 + psi2
    # Entropy, converted to bits by dividing by log(2)
    S1 = -numpy.sum(eta1 * theta1, axis=1) + psi1[:, 1]
    S1 /= numpy.log(2)
    S2 = -numpy.sum(eta2 * theta2, axis=1) + psi2[:, 1]
    S2 /= numpy.log(2)
    S_all = S1 + S2

    # Heat capacity from the finite-difference second derivative of psi
    C1 = (psi1[:, 0] - 2. * psi1[:, 1] + psi1[:, 2]) / .001 ** 2
    C1 /= numpy.log(2)
    C2 = (psi2[:, 0] - 2. * psi2[:, 1] + psi2[:, 2]) / .001 ** 2
    C2 /= numpy.log(2)

    C_all = C1 + C2

    spikes = synthesis.generate_spikes_gibbs_parallel(theta_all, 2 * N, O, R,
                                                      sample_steps=10,
                                                      num_proc=4)

    # Parenthesised single-argument print behaves the same on Python 2 and 3
    print('Model and Data generated')

    emd = __init__.run(spikes, O, map_function='cg', param_est='pseudo',
                       param_est_eta='bethe_hybrid', lmbda1=100, lmbda2=200)

    h5_file = data_path + 'figure1data.h5'
    # The context manager closes the file even if writing a data set fails
    with h5py.File(h5_file, 'w') as f:
        g_data = f.create_group('data')
        g_data.create_dataset('theta_all', data=theta_all)
        g_data.create_dataset('psi_all', data=psi_all)
        g_data.create_dataset('S_all', data=S_all)
        g_data.create_dataset('C_all', data=C_all)
        g_data.create_dataset('spikes', data=spikes)
        g_data.create_dataset('theta1', data=theta1)
        g_data.create_dataset('theta2', data=theta2)
        g_data.create_dataset('psi1', data=psi1)
        g_data.create_dataset('S1', data=S1)
        g_data.create_dataset('C1', data=C1)
        g_data.create_dataset('psi2', data=psi2)
        g_data.create_dataset('S2', data=S2)
        g_data.create_dataset('C2', data=C2)
        g_fit = f.create_group('fit')
        g_fit.create_dataset('theta_s', data=emd.theta_s)
        g_fit.create_dataset('sigma_s', data=emd.sigma_s)
        g_fit.create_dataset('Q', data=emd.Q)

    print('Fit and saved')

    # Read the fit back from disk. `dataset[()]` replaces the `.value`
    # attribute, which no longer exists in h5py >= 3.0.
    with h5py.File(h5_file, 'r') as f:
        theta = f['fit']['theta_s'][()]
        sigma = f['fit']['sigma_s'][()]

    # Draw Gaussian samples around the fitted parameters.
    # NOTE(review): assumes sigma_s holds per-parameter variances with the
    # same (time, parameter) layout as theta_s -- confirm.
    n_samples = 100
    X = numpy.random.randn(theta.shape[0], theta.shape[1], n_samples)
    theta_sampled = \
        theta[:, :, numpy.newaxis] + X * numpy.sqrt(sigma)[:, :, numpy.newaxis]

    time_bins = range(theta.shape[0])
    eta_sampled = numpy.empty([theta.shape[0], theta.shape[1], n_samples])
    psi_sampled = numpy.empty([theta.shape[0], n_samples, 3])

    func = partial(get_sampled_eta_psi, theta_sampled=theta_sampled, N=2*N)
    pool = multiprocessing.Pool(10)
    try:
        results = pool.map(func, time_bins)
    finally:
        # Always release the worker processes
        pool.close()
        pool.join()

    for eta, psi, i in results:
        eta_sampled[i] = eta
        psi_sampled[i] = psi
    S_sampled = \
        -(numpy.sum(eta_sampled*theta_sampled, axis=1) - psi_sampled[:, :, 1])
    S_sampled /= numpy.log(2)
    C_sampled = \
        (psi_sampled[:, :, 0] - 2.*psi_sampled[:, :, 1] +
         psi_sampled[:, :, 2])/.001**2
    C_sampled /= numpy.log(2)

    with h5py.File(h5_file, 'r+') as f:
        g_sampled = f.create_group('sampled_results')
        g_sampled.create_dataset('theta_sampled', data=theta_sampled)
        g_sampled.create_dataset('eta_sampled', data=eta_sampled)
        g_sampled.create_dataset('psi_sampled', data=psi_sampled)
        g_sampled.create_dataset('S_sampled', data=S_sampled)
        g_sampled.create_dataset('C_sampled', data=C_sampled)

    print('Done')
Beispiel #15
0
def _draw_axis_lines(ax, bottom=True):
    """
    Draw black axis lines along the current view limits of `ax`.

    Intended for panels whose frame is switched off: the left axis (and
    optionally the bottom axis) is added by hand as a Line2D artist.

    :param ax:
        Matplotlib axes to decorate.
    :param bool bottom:
        If True, the (thicker) bottom axis line is drawn as well.
    """
    ymin, ymax = ax.get_yaxis().get_view_interval()
    xmin, xmax = ax.get_xaxis().get_view_interval()
    ax.add_artist(pyplot.Line2D((xmin, xmin), (ymin, ymax), color='black',
                                linewidth=2))
    if bottom:
        ax.add_artist(pyplot.Line2D((xmin, xmax), (ymin, ymin),
                                    color='black', linewidth=3))


def plot_figure1(data_path='../Data/', plot_path='../Plots/'):
    """
    Plot figure 1 from the data stored in `data_path + 'figure1data.h5'`.

    Panel A shows the spike rasters of the first three trials and the mean
    spike probability over time. Panel B shows, for three time bins, the
    network of couplings whose credible interval excludes zero, together
    with the time course of three selected couplings. Panel C shows the
    spike probability, silence probability, entropy and heat capacity with
    the 1%-99% quantile band of the sampled results.

    The figure is saved as `plot_path + 'fig1.eps'` and then shown.

    :param string data_path:
        Directory prefix (including the trailing separator) of the HDF5
        input file.
    :param string plot_path:
        Directory prefix (including the trailing separator) of the output
        figure.

    :returns:
        None.
    """
    N, O = 30, 2
    # Time bins highlighted by the network snapshots and the x-ticks
    snapshots = [50, 150, 300]
    # The context manager guarantees that the HDF5 file is closed again
    with h5py.File(data_path + 'figure1data.h5', 'r') as f:
        # Figure A
        fig = pyplot.figure(figsize=(30, 20))

        # Rasters of trials 0, 1 and 2, stacked with a small offset
        raster_origins = [(0.07, 0.68), (.06, 0.65), (.05, 0.62)]
        for trial, (left, lower) in enumerate(raster_origins):
            ax = fig.add_axes([left, lower, .4, .4])
            ax.imshow(-f['data']['spikes'][:, trial, :].T, cmap='gray',
                      aspect=5, interpolation='nearest')
            ax.set_xticks([])
            ax.set_yticks([])
        # Only the front-most raster carries axis labels
        ax.set_xlabel('Time [AU]', fontsize=26)
        ax.set_ylabel('Neuron ID', fontsize=26)

        ax = fig.add_axes([.05, 0.5, .4, .2])
        ax.set_frame_on(False)
        ax.plot(numpy.mean(numpy.mean(f['data']['spikes'][:, :, :], axis=1),
                           axis=1), linewidth=4, color='k')
        _draw_axis_lines(ax)
        ax.yaxis.set_ticks_position('left')
        ax.xaxis.set_ticks_position('bottom')
        ax.set_yticks([.1, .2, .3])
        ax.set_xticks(snapshots)
        ax.set_ylabel('Data $p_{\\mathrm{spike}}$', fontsize=26)
        ax.set_xlabel('Time [AU]', fontsize=26)
        ax.tick_params(axis='both', which='major', labelsize=20)

        # `dataset[()]` replaces `.value`, which was removed in h5py 3.0
        theta = f['fit']['theta_s'][()]
        sigma_s = f['fit']['sigma_s'][()]
        # Number of time bins, taken from the data instead of hard-coded
        time_axis = range(theta.shape[0])
        # Bounds of the couplings (columns N and onwards): mean +- 2.58 sd
        half_width = 2.58 * numpy.sqrt(sigma_s[:, N:])
        bounds = numpy.empty([theta.shape[0], theta.shape[1] - N, 2])
        bounds[:, :, 0] = theta[:, N:] - half_width
        bounds[:, :, 1] = theta[:, N:] + half_width
        # Figure B (Networks)
        graph_ax = [fig.add_axes([.52, 0.78, .15, .2]),
                    fig.add_axes([.67, 0.78, .15, .2]),
                    fig.add_axes([.82, 0.78, .15, .2])]
        # All neuron pairs, in the same order as the coupling columns
        all_conns = numpy.array(list(itertools.combinations(range(N), 2)))
        for i, t in enumerate(snapshots):
            # Couplings whose interval lies entirely above or below zero
            conn_idx = numpy.where(numpy.logical_or(bounds[t, :, 0] > 0,
                                                    bounds[t, :, 1] < 0))[0]
            conns = all_conns[conn_idx]
            G1 = nx.Graph()
            G1.add_nodes_from(range(N))
            G1.add_edges_from(conns)
            pos1 = nx.circular_layout(G1)
            net_nodes = \
                nx.draw_networkx_nodes(G1, pos1, ax=graph_ax[i],
                                       node_color=theta[t, :N],
                                       cmap=pyplot.get_cmap('hot'), vmin=-3,
                                       vmax=-1.)
            # conn_idx counts within the coupling block, which starts at
            # column N of theta; without the offset the edges would be
            # coloured with first-order parameters.
            e1 = nx.draw_networkx_edges(
                G1, pos1, ax=graph_ax[i],
                edge_color=theta[t, N + conn_idx].tolist(),
                edge_cmap=pyplot.get_cmap('seismic'),
                edge_vmin=-.7, edge_vmax=.7, width=2)
            graph_ax[i].axis('off')
            x0, x1 = graph_ax[i].get_xlim()
            y0, y1 = graph_ax[i].get_ylim()
            graph_ax[i].set_aspect(abs(x1 - x0) / abs(y1 - y0))
            graph_ax[i].set_title('t=%d' % t, fontsize=24)
        # Colour bars refer to the artists of the last snapshot
        cbar_ax = fig.add_axes([0.62, 0.79, 0.1, 0.01])
        cbar_ax.tick_params(axis='both', which='major', labelsize=20)
        cbar = fig.colorbar(net_nodes, cax=cbar_ax, orientation='horizontal')
        cbar.set_ticks([-3, -2, -1])
        cbar_ax.set_title('$\\theta_{i}$', fontsize=22)
        cbar_ax = fig.add_axes([0.77, 0.79, 0.1, 0.01])
        cbar = fig.colorbar(e1, cax=cbar_ax, orientation='horizontal')
        cbar.set_ticks([-.5, 0., .5])
        cbar_ax.set_title('$\\theta_{ij}$', fontsize=22)
        cbar_ax.tick_params(axis='both', which='major', labelsize=20)

        # Figure B (Thetas)
        shown = [165, 170, 182]
        theta_true = f['data']['theta_all'][:, shown]
        theta_fit = theta[:, shown]
        sigma_fit = sigma_s[:, shown]
        # One row per panel: lower edge, grey level, y-limits, y-label
        panels = [(0.68, .4, [-1.1, 1.1], None),
                  (0.57, .5, [-1.1, 1.5], '$\\theta_{ij}$'),
                  (0.46, .6, [-1.1, 1.1], None)]
        for k, (lower, grey, ylim, ylabel) in enumerate(panels):
            # Only the lowest panel gets a bottom axis with ticks and label
            is_last = k == len(panels) - 1
            ax1 = fig.add_axes([.55, lower, .4, .1])
            ax1.set_frame_on(False)
            ax1.fill_between(
                time_axis,
                theta_fit[:, k] - 2.58 * numpy.sqrt(sigma_fit[:, k]),
                theta_fit[:, k] + 2.58 * numpy.sqrt(sigma_fit[:, k]),
                color=[grey, grey, grey])
            ax1.plot(time_axis, theta_true[:, k], linewidth=4, color='k')
            ax1.set_yticks([-1, 0, 1])
            ax1.set_ylim(ylim)
            _draw_axis_lines(ax1, bottom=is_last)
            ax1.set_xticks(snapshots if is_last else [])
            ax1.yaxis.set_ticks_position('left')
            ax1.xaxis.set_ticks_position('bottom')
            if ylabel is not None:
                ax1.set_ylabel(ylabel, fontsize=26)
            if is_last:
                ax1.set_xlabel('Time [AU]', fontsize=26)
            ax1.tick_params(axis='both', which='major', labelsize=20)

        # Figure C
        psi_color = numpy.array([51, 153., 255]) / 256.
        eta_color = numpy.array([0, 204., 102]) / 256.
        S_color = numpy.array([255, 162, 0]) / 256.
        C_color = numpy.array([204, 60, 60]) / 256.
        psi_quantiles = mquantiles(
            f['sampled_results']['psi_sampled'][:, :, 1],
            prob=[.01, .99], axis=1)
        # NOTE(review): psi_all is stored with three columns (one per alpha),
        # so the line plot below draws three nearly identical curves, while
        # the quantiles use column 1 only -- confirm this is intended.
        psi_true = f['data']['psi_all'][()]
        eta_quantiles = mquantiles(numpy.mean(
            f['sampled_results']['eta_sampled'][:, :N, :], axis=1),
            prob=[.01, .99], axis=1)
        C_quantiles = mquantiles(f['sampled_results']['C_sampled'][:, :],
                                 prob=[.01, .99], axis=1)
        C_true = f['data']['C_all'][()]
        S_quantiles = mquantiles(f['sampled_results']['S_sampled'][:, :],
                                 prob=[.01, .99], axis=1)
        S_true = f['data']['S_all'][()]
        # Read both parameter sets once instead of row by row from disk
        theta1 = f['data']['theta1'][()]
        theta2 = f['data']['theta2'][()]
        eta1 = numpy.empty(theta1.shape)
        eta2 = numpy.empty(theta2.shape)
        N1, N2 = 15, 15
        transforms.initialise(N1, O)
        for i in range(eta1.shape[0]):
            p = transforms.compute_p(theta1[i])
            eta1[i] = transforms.compute_eta(p)
            p = transforms.compute_p(theta2[i])
            eta2[i] = transforms.compute_eta(p)

        # Spike probability
        ax1 = fig.add_axes([.08, 0.23, .4, .15])
        ax1.set_frame_on(False)
        ax1.fill_between(time_axis, eta_quantiles[:, 0], eta_quantiles[:, 1],
                         color=eta_color)
        eta_true = 1. / 2. * (numpy.mean(eta1[:, :N1], axis=1) +
                              numpy.mean(eta2[:, :N2], axis=1))
        ax1.plot(time_axis, eta_true, linewidth=4, color=eta_color * .8)

        ax1.set_yticks([.1, .2, .3])
        ax1.set_ylim([.09, .35])
        _draw_axis_lines(ax1)
        ax1.set_xticks(snapshots)
        ax1.yaxis.set_ticks_position('left')
        ax1.xaxis.set_ticks_position('bottom')
        ax1.set_ylabel('$p_{\\mathrm{spike}}$', fontsize=26)
        ax1.tick_params(axis='both', which='major', labelsize=20)

        # Silence probability
        ax1 = fig.add_axes([.08, 0.05, .4, .15])
        ax1.set_frame_on(False)
        ax1.fill_between(time_axis, numpy.exp(-psi_quantiles[:, 0]),
                         numpy.exp(-psi_quantiles[:, 1]), color=psi_color)
        ax1.plot(time_axis, numpy.exp(-psi_true), linewidth=4,
                 color=psi_color * .8)
        ax1.set_yticks([.0, .01, .02])
        ax1.set_ylim([.0, .025])
        _draw_axis_lines(ax1)
        ax1.set_xticks(snapshots)
        ax1.yaxis.set_ticks_position('left')
        ax1.xaxis.set_ticks_position('bottom')
        ax1.set_ylabel('$p_{\\mathrm{silence}}$', fontsize=26)
        ax1.set_xlabel('Time [AU]', fontsize=26)
        ax1.tick_params(axis='both', which='major', labelsize=20)

        # Stored entropy and heat capacity are divided by log2(e) for display
        log2_e = numpy.log2(numpy.exp(1))
        # Entropy
        ax2 = fig.add_axes([.52, 0.23, .4, .15])
        ax2.set_frame_on(False)
        ax2.fill_between(time_axis, S_quantiles[:, 0] / log2_e,
                         S_quantiles[:, 1] / log2_e, color=S_color)
        ax2.plot(time_axis, S_true / log2_e, linewidth=4,
                 color=S_color * .8)
        ax2.set_xticks(snapshots)
        ax2.set_yticks([10, 14, 18])
        _draw_axis_lines(ax2)
        ax2.yaxis.set_ticks_position('left')
        ax2.xaxis.set_ticks_position('bottom')
        ax2.set_ylabel('$S$', fontsize=26)
        ax2.tick_params(axis='both', which='major', labelsize=20)
        # Heat capacity
        ax2 = fig.add_axes([.52, 0.05, .4, .15])
        ax2.set_frame_on(False)
        ax2.fill_between(time_axis, C_quantiles[:, 0] / log2_e,
                         C_quantiles[:, 1] / log2_e, color=C_color)
        ax2.plot(time_axis, C_true / log2_e, linewidth=5,
                 color=C_color * .8)
        _draw_axis_lines(ax2)
        ax2.set_xticks(snapshots)
        ax2.set_yticks([5, 10])
        ax2.yaxis.set_ticks_position('left')
        ax2.xaxis.set_ticks_position('bottom')
        ax2.set_xlabel('Time [AU]', fontsize=26)
        ax2.set_ylabel('$C$', fontsize=26)
        ax2.tick_params(axis='both', which='major', labelsize=20)

    # Panel letters, each placed in its own invisible axes
    for letter, origin in (('A', [0.03, 0.95]),
                           ('B', [0.52, 0.95]),
                           ('C', [0.05, 0.4])):
        ax = fig.add_axes(origin + [.05, .05], frameon=0)
        ax.set_yticks([])
        ax.set_xticks([])
        ax.text(.0, .0, letter, fontsize=26, fontweight='bold')
    fig.savefig(plot_path+'fig1.eps')
    pyplot.show()
Beispiel #16
0
O = 2


# ----- SPIKE SYNTHESIS -----
# Global module
import numpy
# Local modules
import synthesis
import transforms

# Generate the time-varying theta parameters for N cells, interaction order O
# and T time bins
theta = synthesis.generate_thetas(N, O, T)

# The transforms library is initialised before P is computed
transforms.initialise(N, O)
# P holds one row per time step and one column per spike pattern (2**N)
p = numpy.zeros((T, 2 ** N))
for i in range(T):
    p[i] = transforms.compute_p(theta[i])
# Generate spikes for R runs from P, with a fixed seed
spikes = synthesis.generate_spikes(p, R, seed=1)


# ----- ALGORITHM EXECUTION -----
# Global module
import numpy
# Local module
import __init__ # From outside this folder, this would be 'import ssll'

# Run the algorithm!
Beispiel #17
0
def run(spikes, D, map_function='nr', window=1, lmbda=100, max_iter=1,
        J=None, theta_o=0, sigma_o=0.1, exact=True):
    """
    Entry point of the State-Space Analysis of Spike Correlation package.

    Builds a container.EMData object from the arguments and alternates the
    expectation and maximisation steps on it. Iteration stops once `max_iter`
    iterations have been performed or `emd.convergence` no longer exceeds
    `exp_max.CONVERGED`. After every iteration, `emd.convergence` is set to
    the relative change of the log marginal likelihood,
    abs(1 - previous / current).

    Note that this function changes the global numpy error handling so that
    invalid operations and underflows raise exceptions.

    :param numpy.ndarray spikes:
        Binary matrix with dimensions (time, runs, cells), in which a `1' in
        location (t, r, c) denotes a spike at time t in run r by cell c.
    :param int D:
        Number of graphs to consider in the fitting of the model.
    :param string map_function:
        Key into `max_posterior.functions` selecting the function used for
        maximum a-posteriori estimation of the natural parameters at each
        timestep. Refer to max_posterior.py.
    :param int window:
        Bin-width for counting spikes, in milliseconds.
    :param float lmbda:
        Coefficient on the identity matrix of the initial state-transition
        covariance matrix.
    :param int max_iter:
        Maximum number of iterations for which to run the EM algorithm.
    :param numpy.ndarray J:
        First estimate of the graphs matrix.
    :param theta_o:
        Forwarded unchanged to container.EMData.
        NOTE(review): presumably the initial mean of theta -- confirm.
    :param sigma_o:
        Forwarded unchanged to container.EMData.
        NOTE(review): presumably the initial variance of theta -- confirm.
    :param boolean exact:
        Whether to use the exact solution (True) or the Bethe approximation
        (False) for the computation of eta, psi and Fisher info. The
        coordinate-transform maps are only initialised when this is True.

    :returns:
        Results encapsulated in a container.EMData object, containing the
        smoothed posterior probability distributions of the natural parameters
        of the spike-train interactions at each timestep, conditional upon the
        given spikes.
    """
    # Make numpy raise on invalid operations and underflows so that NaNs are
    # caught
    numpy.seterr(invalid='raise', under='raise')
    # The graphs only contain interactions up to second order
    order_ising = 2
    # Set up the EM-data container with the selected MAP function
    emd = container.EMData(spikes, window, D,
                           max_posterior.functions[map_function], lmbda, J,
                           theta_o, sigma_o, exact)
    # The coordinate-transform maps are only needed for the exact solution
    if exact:
        transforms.initialise(emd.N, order_ising)
    # Log marginal likelihood before the first iteration
    current_llk = probability.log_marginal(emd)
    # Iterate the EM algorithm until convergence or the iteration limit
    while emd.iterations < max_iter and \
            emd.convergence > exp_max.CONVERGED:
        # One full EM iteration
        exp_max.e_step(emd)
        exp_max.m_step(emd)
        # Shift the current log marginal to `previous` and recompute it
        previous_llk, current_llk = current_llk, probability.log_marginal(emd)
        # Update the EM algorithm metadata
        emd.iterations += 1
        emd.convergence = abs(1 - previous_llk / current_llk)

    return emd