def ais_latent_network_given_A(x0, graph_model, graph_sampler, N_samples=1000, B=100,
                               steps_per_B=11):
    """
    Use annealed importance sampling (AIS) to approximate the log marginal
    likelihood of a latent network model.

    :param x0: starting state. It is deep-copied at the beginning of every
               AIS run, so the sampler only ever works on a copy. It must
               be indexable as x['net']['graph'].
    :param graph_model: model exposing ``lkhd_scale.set_value(beta)``,
                        ``log_p`` and ``get_variables()``.
    :param graph_sampler: object whose ``update(x)`` returns the next state.
    :param N_samples: number of independent AIS runs (importance samples).
    :param B: number of likelihood scales, linearly spaced on [0, 1].
    :param steps_per_B: number of sampler updates performed at each scale.
    :return: log of the average importance weight, i.e.
             logsumexp(log_weights) - log(N_samples).

    Side effects: progress is written to stdout, and
    ``graph_model.lkhd_scale`` is left at the last value set inside the
    loop (the second-to-last beta) rather than restored.
    """
    betas = np.linspace(0,1,B)

    # One log importance weight per independent AIS run
    log_weights = np.zeros(N_samples)
    for m in range(N_samples):
        # Every run starts from a fresh copy of the supplied state.
        # NOTE(review): presumably x0 is meant to be a draw from the prior
        # (the beta=0 distribution) - confirm with the caller.
        x = copy.deepcopy(x0)

        # ratios[n-1] holds log p(x; beta_n) - log p(x; beta_{n-1})
        ratios = np.zeros(B-1)

        # Walk through the intermediate distributions
        for n in range(1, B):
            beta = betas[n]
            sys.stdout.write("M: %d\tBeta: %.3f \r" % (m,beta))
            sys.stdout.flush()

            # Set the likelihood scale (beta) in the graph model
            graph_model.lkhd_scale.set_value(beta)

            # Take steps_per_B sampler steps at this beta
            for _ in range(steps_per_B):
                x = graph_sampler.update(x)

            # Log probability of the updated state under the current beta...
            curr_lkhd = seval(graph_model.log_p,
                              graph_model.get_variables(),
                              x['net']['graph'])

            # ...and under the previous beta
            graph_model.lkhd_scale.set_value(betas[n-1])
            prev_lkhd = seval(graph_model.log_p,
                              graph_model.get_variables(),
                              x['net']['graph'])

            # NOTE(review): in Neal's AIS the ratio for step n is evaluated
            # at the state *before* the transition at beta_n; here it is
            # evaluated after the updates. Confirm this ordering is intended.
            ratios[n-1] = curr_lkhd - prev_lkhd

        # The log weight of this run is the sum of the log ratios
        log_weights[m] = np.sum(ratios)

        print("")
        print("W: %f" % log_weights[m])

    # Average the weights (in log space) to estimate the normalization constant
    log_Z = -np.log(N_samples) + logsumexp(log_weights)
    return log_Z
Example #2
0
    def plot(self,
             xs,
             ax=None,
             title=None,
             vmin=None,
             vmax=None,
             cmap=rwb_cmap):
        """ Show the sample-averaged effective connectivity as an image.

        The effective connectivity is the mean over samples of the
        elementwise product of the weight matrix W and the adjacency
        matrix A.

        :param xs: a sample or a list of samples; each one is indexed as
                   x['net']['weights'] and x['net']['graph']
        :param ax: axes to draw on; a new figure is created when None
        :param title: title passed to ax.set_title
        :param vmin: lower color limit
        :param vmax: upper color limit. If either limit is None, both are
                     replaced by symmetric limits +/- max(|W_inf|)
        :param cmap: colormap handed to imshow
        """
        # Accept a bare sample as well as a list of samples
        samples = xs if isinstance(xs, list) else [xs]

        # Evaluate the weight matrix for every sample
        weight_syms = self.population.network.weights.get_variables()
        weight_mats = []
        for smpl in samples:
            weight_mats.append(seval(self.population.network.weights.W,
                                     weight_syms,
                                     smpl['net']['weights']))
        Ws = np.array(weight_mats)

        # Evaluate the adjacency matrix for every sample
        graph_syms = self.population.network.graph.get_variables()
        adj_mats = []
        for smpl in samples:
            adj_mats.append(seval(self.population.network.graph.A,
                                  graph_syms,
                                  smpl['net']['graph']))
        As = np.array(adj_mats)

        # Average the masked weights over the sample axis
        W_inf = np.mean(Ws * As, axis=0)

        # Fall back to symmetric color limits when either bound is missing
        if None in (vmax, vmin):
            vmax = np.amax(np.abs(W_inf))
            vmin = -vmax

        # Draw on a fresh figure if the caller did not supply axes
        if ax is None:
            ax = plt.figure().add_subplot(1, 1, 1)

        # Blow every matrix entry up to a 10x10 pixel tile
        tile = np.ones((10, 10))
        ax.imshow(np.kron(W_inf, tile),
                  vmin=vmin,
                  vmax=vmax,
                  extent=[0, 1, 0, 1],
                  interpolation='nearest',
                  cmap=cmap)

        ax.set_title(title)
Example #3
0
    def compute_log_prior(self, vars):
        """ Compute the log prior probability under a given set of variables.

        Sums the latent-variable log probability, the network log
        probability and the GLM log prior of each of the self.N neurons.
        """
        syms = self.get_variables()

        # Terms that are evaluated once for the whole population:
        # (symbolic expression, key into both syms and vars)
        population_terms = (
            (self.latent.log_p, 'latent'),
            (self.network.log_p, 'net'),
        )

        total = 0.0
        for (expr, key) in population_terms:
            total += seval(expr, syms[key], vars[key])

        # The shared GLM prior is evaluated once per neuron
        for idx in range(self.N):
            total += seval(self.glm.log_prior, syms,
                           self.extract_vars(vars, idx))

        return total
Example #4
0
 def _eval_state_helper(self, syms, d, vars):
     """ Helper function to recursively evaluate state variables
     """
     state = {}
     for (k,v) in d.items():
         if isinstance(v,dict):
             state[k] = self._eval_state_helper(syms, v, vars)
         else:
             state[k] = seval(v, syms, vars)
     return state
Example #5
0
 def _eval_state_helper(self, syms, d, vars):
     """ Helper function to recursively evaluate state variables
     """
     state = {}
     for (k, v) in d.items():
         if isinstance(v, dict):
             state[k] = self._eval_state_helper(syms, v, vars)
         else:
             state[k] = seval(v, syms, vars)
     return state
Example #6
0
    def grad_nlp(x_glm_vec, x):
        """
        Gradient of the negative log posterior with respect to the GLM
        parameters, which are supplied as a flat vector.

        The vector is unpacked with glm_shapes and written into x['glm']
        before the gradient of the log prior and of the log likelihood of
        every data sequence are accumulated. Note that this changes the
        population's current data as it iterates over the sequences.
        """
        # Install the candidate parameters into the GLM slot of x
        set_vars(glm_syms, x['glm'], unpackdict(x_glm_vec, glm_shapes))

        # Start from the gradient of the log prior
        grad = seval(g_glm_logprior, syms, x)

        # Accumulate the likelihood gradient, one data sequence at a time
        for seq in population.data_sequences:
            population.set_data(seq)
            grad += seval(g_glm_ll, syms, x)

        # Flip the sign: callers want the *negative* log posterior
        return -1.0 * grad
Example #7
0
    def plot(self, xs, ax=None, name='location_provider', color='k'):
        """
        Plot the inferred locations for each neuron.

        A single sample is drawn as one marker per neuron (a vertical line
        for 1D locations, a labelled square for 2D locations). Several
        samples are summarized with a histogram per neuron.

        :param xs: a sample or a list of samples; each one is indexed as
                   x['latent'][name]
        :param ax: axes to draw on; defaults to the current pyplot axes
        :param name: key of the location component in the latent variables.
                     Nothing is plotted if xs[0]['latent'] does not have it.
        :param color: color of lines, markers, labels and 1D histograms
        :raises Exception: if the locations are not 1- or 2-dimensional
        """
        # Ensure sample is a list
        if not isinstance(xs, list):
            xs = [xs]

        if name not in xs[0]['latent']:
            return

        # Every branch below draws on ``ax``; fall back to the current axes
        # so that calling without ``ax`` works for all of them.
        if ax is None:
            ax = plt.gca()

        # Get the locations
        loccomp = self.population.latent.latentdict[name]
        locprior = loccomp.location_prior
        locvars = loccomp.get_variables()
        Ls = np.array([seval(loccomp.Lmatrix,
                            locvars, x['latent'][name])
                       for x in xs])
        # Ls is indexed as (sample, neuron, location dimension)
        [N_smpls, N, D] = Ls.shape

        for n in range(N):
            if N_smpls == 1:
                if D == 1:
                    # Vertical line at the neuron's location
                    ax.plot([Ls[0,n,0], Ls[0,n,0]],
                            [0,2], color=color, lw=2)
                elif D == 2:
                    # Dimension 1 goes on the x axis, dimension 0 on the y axis
                    ax.plot(Ls[0,n,1], Ls[0,n,0], 's',
                             color=color, markerfacecolor=color)
                    ax.text(Ls[0,n,1]+0.25, Ls[0,n,0]+0.25, '%d' % n,
                             color=color)

                    # Set the limits (y axis is inverted: max first)
                    # NOTE(review): x data is dimension 1 but the x limits
                    # use min0/max0 (and y vice versa) - confirm intended.
                    ax.set_xlim((locprior.min0-0.5, locprior.max0+0.5))
                    ax.set_ylim((locprior.max1+0.5, locprior.min1-0.5))
                else:
                    raise Exception("Only plotting locs of dim <= 2")
            else:
                # Plot a histogram of samples
                # NOTE(review): newer matplotlib releases replace ``normed``
                # with ``density``; kept as-is for the version this targets.
                if D == 1:
                    ax.hist(Ls[:,n,0], bins=20, normed=True, color=color)
                elif D == 2:
                    ax.hist2d(Ls[:,n,1], Ls[:,n,0], bins=np.arange(-0.5,5), cmap='Reds', alpha=0.5, normed=True)

                    # Set the limits (y axis is inverted: max first)
                    ax.set_xlim((locprior.min0-0.5, locprior.max0+0.5))
                    ax.set_ylim((locprior.max1+0.5, locprior.min1-0.5))
                else:
                    raise Exception("Only plotting locs of dim <= 2")
Example #8
0
    def grad_nlp(x_glm_vec, x):
        """
        Gradient of the negative log prior with respect to the network
        parameters, which are supplied as a flat vector.

        Despite its name, x_glm_vec is unpacked with network_shapes and
        written into x['net'] before the gradient is evaluated.
        """
        # Install the candidate parameters into the network slot of x
        set_vars(network_syms, x['net'], unpackdict(x_glm_vec, network_shapes))

        # Only the network log prior contributes here
        grad = seval(g_network_logprior, syms, x)

        # Flip the sign to obtain the gradient of the negative log prior
        return -1.0 * grad
Example #9
0
    def plot(self, xs, ax=None, title=None, vmin=None, vmax=None, cmap=rwb_cmap):
        """ Render the mean effective connectivity matrix with imshow.

        For every sample the weight matrix W and adjacency matrix A are
        evaluated; the image shows mean(W * A) over samples, with each
        matrix entry enlarged to a 10x10 pixel block.

        :param xs: one sample or a list of samples, each indexed as
                   x['net']['weights'] and x['net']['graph']
        :param ax: target axes (a new figure is made when None)
        :param title: axes title
        :param vmin: lower color limit
        :param vmax: upper color limit; when either limit is None both are
                     reset to -/+ the largest absolute entry
        :param cmap: colormap for the image
        """
        # Wrap a single sample so the code below can always iterate
        if not isinstance(xs, list):
            xs = [xs]

        weights = self.population.network.weights
        w_syms = weights.get_variables()
        Ws = np.array([seval(weights.W, w_syms, sample['net']['weights'])
                       for sample in xs])

        graph = self.population.network.graph
        g_syms = graph.get_variables()
        As = np.array([seval(graph.A, g_syms, sample['net']['graph'])
                       for sample in xs])

        # Effective connectivity: masked weights averaged over samples
        masked = Ws*As
        W_inf = np.mean(masked, axis=0)

        # Use symmetric color limits unless both were given
        if None in (vmax,vmin):
            vmax = np.amax(np.abs(W_inf))
            vmin = -vmax

        # Create a figure if necessary
        if ax is None:
            new_fig = plt.figure()
            ax = new_fig.add_subplot(1,1,1)

        # Upsample so every node pair covers a square block of pixels
        side = 10
        upsampled = np.kron(W_inf, np.ones((side, side)))
        ax.imshow(upsampled,
                  vmin=vmin, vmax=vmax,
                  extent=[0,1,0,1],
                  interpolation='nearest',
                  cmap=cmap)

        ax.set_title(title)
Example #10
0
    def compute_log_prior(self, vars):
        """ Compute the log prior probability under a given set of variables.

        The result is the sum of the latent log probability, the network
        log probability, and one GLM log prior per neuron.
        """
        # Symbolic variables matching the layout of ``vars``
        symbolic = self.get_variables()

        log_prior = 0.0

        # Population-level terms
        log_prior += seval(self.latent.log_p, symbolic['latent'], vars['latent'])
        log_prior += seval(self.network.log_p, symbolic['net'], vars['net'])

        # Per-neuron terms, using the n-th neuron's slice of the variables
        neuron = 0
        while neuron < self.N:
            neuron_vars = self.extract_vars(vars, neuron)
            log_prior += seval(self.glm.log_prior, symbolic, neuron_vars)
            neuron += 1

        return log_prior
Example #11
0
    def compute_ll(self, vars):
        """ Compute the log likelihood under a given set of variables.

        Sums the GLM log likelihood over the self.N neurons, evaluating
        the shared GLM on each neuron's slice of the variables.
        """
        syms = self.get_variables()

        # Lazily pull out the variables of one neuron at a time
        per_neuron_vars = (self.extract_vars(vars, idx)
                           for idx in range(self.N))

        total = 0.0
        for nvars in per_neuron_vars:
            total += seval(self.glm.ll, syms, nvars)

        return total
Example #12
0
    def compute_ll(self, vars):
        """ Compute the log likelihood under a given set of variables.

        The shared GLM likelihood is evaluated once per neuron and the
        results are added together.
        """
        symbolic = self.get_variables()
        log_lkhd = 0.0

        # Add the likelihood contribution of each neuron in turn
        neuron = 0
        while neuron < self.N:
            neuron_vars = self.extract_vars(vars, neuron)
            log_lkhd += seval(self.glm.ll, symbolic, neuron_vars)
            neuron += 1

        return log_lkhd
Example #13
0
class Population:
    """
    Population of GLMs connected by a network.

    Holds the latent variables, the network model and a single GLM that is
    shared across all neurons, together with the list of data sequences
    the population has been given.
    """
    def __init__(self, model):
        """
        Initialize the population of GLMs connected by a network.

        :param model: model specification dict; model['N'] is the number of
                      neurons. It is also handed to every component.
        """
        self.model = model
        self.N = model['N']

        # Initialize a list of data sequences
        self.data_sequences = []

        # Initialize latent variables of the population
        self.latent = LatentVariables(model)

        # Create a network model to connect the GLMs
        self.network = Network(model, self.latent)

        # Create a single GLM that is shared across neurons
        # This is to simplify the model and reuse parameters.
        # Basically it speeds up the gradient calculations since we
        # can manually leverage conditional independencies among GLMs
        self.glm = Glm(model, self.network, self.latent)

    def compute_log_p(self, vars):
        """ Compute the log joint probability under a given set of variables

        This is the log prior plus the log likelihood of every stored data
        sequence. Each sequence is made the current data in turn, so the
        last sequence is the current data when this returns.
        """
        lp = 0.0
        lp += self.compute_log_prior(vars)

        # Add the likelihood of each data sequence
        for data in self.data_sequences:
            self.set_data(data)
            lp += self.compute_ll(vars)

        return lp

    def compute_log_prior(self, vars):
        """ Compute the log prior probability under a given set of variables

        Sums the latent log probability, the network log probability and
        the GLM log prior of each neuron.
        """
        lp = 0.0

        # Get set of symbolic variables
        syms = self.get_variables()

        lp += seval(self.latent.log_p, syms['latent'], vars['latent'])

        lp += seval(self.network.log_p, syms['net'], vars['net'])

        for n in range(self.N):
            nvars = self.extract_vars(vars, n)
            lp += seval(self.glm.log_prior, syms, nvars)

        return lp

    def compute_ll(self, vars):
        """ Compute the log likelihood under a given set of variables

        Evaluated against whatever data is currently set on the components.
        """
        ll = 0.0

        # Get set of symbolic variables
        syms = self.get_variables()

        # Add the likelihood from each GLM
        for n in range(self.N):
            nvars = self.extract_vars(vars, n)
            ll += seval(self.glm.ll, syms, nvars)

        return ll

    def eval_state(self, vars):
        """ Evaluate the population state expressions given the parameters,
            e.g. the stimulus response curves from the basis function weights.

        :return: dict with keys 'latent', 'net', 'glms' (one entry per
                 neuron), 'logprior', 'll' and 'logp'
        """
        # Get set of symbolic variables
        syms = self.get_variables()

        # Get the symbolic state expression to evaluate
        state_vars = self.get_state()
        state = {}

        state['latent'] = self._eval_state_helper(syms['latent'],
                                                  state_vars['latent'],
                                                  vars['latent'])

        state['net'] = self._eval_state_helper(syms['net'], state_vars['net'],
                                               vars['net'])

        glm_states = []
        for n in np.arange(self.N):
            nvars = self.extract_vars(vars, n)
            glm_states.append(
                self._eval_state_helper(syms, state_vars['glm'], nvars))
        state['glms'] = glm_states

        # Finally, evaluate the log probability and the log likelihood.
        # The likelihood uses only the currently set data (compute_ll),
        # not every stored sequence as compute_log_p would.
        state['logprior'] = self.compute_log_prior(vars)
        state['ll'] = self.compute_ll(vars)
        state['logp'] = state['ll'] + state['logprior']
        return state

    def _eval_state_helper(self, syms, d, vars):
        """ Helper function to recursively evaluate state variables

        Nested dicts in d are walked recursively; every other value is
        evaluated with seval. The result has the same keys as d.
        """
        state = {}
        for (k, v) in d.items():
            if isinstance(v, dict):
                state[k] = self._eval_state_helper(syms, v, vars)
            else:
                state[k] = seval(v, syms, vars)
        return state

    def get_variables(self):
        """ Get a dict of the variables of each component, keyed by
            'latent', 'net' and 'glm'
        """
        v = {}
        v['latent'] = self.latent.get_variables()
        v['net'] = self.network.get_variables()
        v['glm'] = self.glm.get_variables()
        return v

    def set_hyperparameters(self, model):
        """ Set the hyperparameters of the model on every component
        """
        self.latent.set_hyperparameters(model)
        self.network.set_hyperparameters(model)
        self.glm.set_hyperparameters(model)

    def sample(self):
        """
        Sample parameters for the population by calling each component's
        sample method.

        :return: dict with keys 'latent', 'net' and 'glms'; 'glms' is a
                 list with one sampled GLM parameter set per neuron, each
                 tagged with its index under 'n'
        """
        v = {}
        # Each component receives the partially filled dict, so later
        # components can see what was sampled before them.
        v['latent'] = self.latent.sample(v)
        v['net'] = self.network.sample(v)
        v['glms'] = []
        for n in range(self.N):
            xn = self.glm.sample(v)
            xn['n'] = n
            v['glms'].append(xn)

        return v

    def extract_vars(self, vals, n):
        """ Hacky helper function to extract the variables for only the
         n-th GLM.

        :return: a new dict with every entry of vals copied over, except
                 that the 'glms' list is replaced by its n-th element
                 under the key 'glm'
        """

        newvals = {}
        for (k, v) in vals.items():
            if k == 'glms':
                newvals['glm'] = v[n]
            else:
                newvals[k] = v
        return newvals

    def get_state(self):
        """ Get the 'state' of the system in symbolic Theano variables
        """
        state = {}
        state['latent'] = self.latent.get_state()
        state['net'] = self.network.get_state()
        state['glm'] = self.glm.get_state()

        return state

    def preprocess_data(self, data):
        """
        Preprocess the data to compute filtered stimuli, spike trains, etc.

        Every component preprocesses the dict, which is then flagged with
        data['preprocessed'] = True and returned.
        """
        assert isinstance(data, dict), 'Data must be a dictionary'
        self.latent.preprocess_data(data)
        self.network.preprocess_data(data)
        self.glm.preprocess_data(data)
        data['preprocessed'] = True
        return data

    def add_data(self, data, set_as_current_data=True):
        """
        Add another data sequence to the population. Recursively call components
        to prepare the new data sequence. E.g. the background model may preprocess
        the stimulus with a set of basis filters.

        :param data: dict that must contain a numpy array under 'S'
        :param set_as_current_data: also condition the population on this
                                    sequence (the default)
        """
        # TODO: Figure out how to handle time varying weights with multiple
        # data sequences. Maybe we only allow one sequence.

        assert isinstance(data, dict), 'Data must be a dictionary'

        # Check for spike times in the data array
        assert 'S' in data, 'Data must contain an array of spike times'
        assert isinstance(data['S'],
                          np.ndarray), 'Spike times must be a numpy array'

        # Preprocess only once per data dict
        if 'preprocessed' not in data or not data['preprocessed']:
            data = self.preprocess_data(data)

        # Add the data to the list
        self.data_sequences.append(data)

        # By default, we set this as the current dataset
        if set_as_current_data:
            self.set_data(data)

    def set_data(self, data):
        """
        Condition on the data

        :param data: a preprocessed data dict (see preprocess_data)
        """
        assert 'preprocessed' in data and data['preprocessed'] == True, \
            'Data must be preprocessed before it can be set'
        self.latent.set_data(data)
        self.network.set_data(data)
        self.glm.set_data(data)

    def simulate(self, vars, T_range, dt, stim, dt_stim):
        """ Simulate spikes from a network of coupled GLMs

        :param vars: parameter dict for the population; the values of each
                     neuron are pulled out with extract_vars
        :param T_range: (T_start, T_stop) pair delimiting the simulation
        :param dt: width of a simulation time bin
        :param stim: stimulus, stored under 'stim' in a temporary data dict
        :param dt_stim: stored under 'dt_stim' in the temporary data dict

        :return: (S, X) where S is the nT x N matrix of spike counts in
                 each bin and X is the nT x N matrix of activations

        Side effects: a temporary data sequence is added and set as the
        current data. Only its entry in data_sequences is removed again, so
        the components remain conditioned on it after the call. Progress
        and diagnostics are printed to stdout.
        """
        # Accepting the pair as a single argument keeps positional callers
        # working while avoiding tuple parameters in the signature.
        (T_start, T_stop) = T_range

        # Set up the time grid
        N = self.model['N']
        t_vec = np.arange(T_start, T_stop, dt)
        t_ind = np.arange(int(T_start / dt), int(T_stop / dt))
        assert len(t_vec) == len(t_ind)
        nT = len(t_vec)

        # Get set of symbolic variables
        syms = self.get_variables()

        # Initialize the activation with the bias current
        X = np.zeros((nT, N))
        for n in np.arange(N):
            nvars = self.extract_vars(vars, n)
            X[:, n] = seval(self.glm.bias_model.I_bias, syms, nvars)

        # Add stimulus induced currents; the stimulus is supplied to the
        # components through a temporary, spike-free data sequence
        temp_data = {'S': np.zeros((nT, N)), 'stim': stim, 'dt_stim': dt_stim}
        self.add_data(temp_data)
        for n in np.arange(N):
            nvars = self.extract_vars(vars, n)
            X[:, n] += seval(self.glm.bkgd_model.I_stim, syms, nvars)
        print("Max background rate: %s" % str(
            self.glm.nlin_model.f_nlin(np.amax(X))))

        # Remove the temp data from the population data sequences
        self.data_sequences.pop()

        # Get the impulse response functions
        imps = []
        for n_post in np.arange(N):
            nvars = self.extract_vars(vars, n_post)
            imps.append(seval(self.glm.imp_model.impulse, syms, nvars))
        # Swap the first two axes of the stacked responses; the last axis
        # is the impulse response length
        imps = np.transpose(np.array(imps), axes=[1, 0, 2])
        T_imp = imps.shape[2]

        # Iterate over each time step and generate spikes.
        # A neuron spikes whenever its accumulated rate exceeds a random
        # threshold, which is drawn as -log(U) with U uniform on [0, 1).
        S = np.zeros((nT, N))
        acc = np.zeros(N)
        thr = -np.log(np.random.rand(N))

        # TODO: Handle time-varying weights appropriately
        time_varying_weights = False
        if not time_varying_weights:
            # Constant connectivity: replicate A and W along the impulse axis
            At = np.tile(
                np.reshape(
                    seval(self.network.graph.A, syms['net'], vars['net']),
                    [N, N, 1]), [1, 1, T_imp])

            Wt = np.tile(
                np.reshape(
                    seval(self.network.weights.W, syms['net'], vars['net']),
                    [N, N, 1]), [1, 1, T_imp])

        # Cap the number of spikes in a time bin
        max_spks_per_bin = 10

        # Count the number of exceptions arising from more spikes per bin than allowable
        n_exceptions = 0
        for t in np.arange(nT):
            # Update accumulator
            if np.mod(t, 10000) == 0:
                print("Iteration %d" % t)
            # TODO Handle nonlinearities with variables
            lam = self.glm.nlin_model.f_nlin(X[t, :])
            acc = acc + lam * dt

            # Spike if accumulator exceeds threshold
            i_spk = acc > thr
            S[t, i_spk] += 1
            n_spk = np.sum(i_spk)

            # Truncate the impulse response at the end of the simulation
            t_imp = np.minimum(nT - t - 1, T_imp)

            # Get the instantaneous connectivity
            if time_varying_weights:
                # TODO: Really get the time-varying weights
                At = np.tile(
                    np.reshape(
                        seval(self.network.graph.A, syms['net'], vars['net']),
                        [N, N, 1]), [1, 1, t_imp])

                Wt = np.tile(
                    np.reshape(
                        seval(self.network.weights.W, syms['net'],
                              vars['net']), [N, N, 1]), [1, 1, t_imp])

            # Iterate until no more spikes
            while n_spk > 0:
                if np.any(S[t, :] >= max_spks_per_bin):
                    n_exceptions += 1
                    break
                # Add weighted impulse response to activation of other neurons
                X[t + 1:t + t_imp + 1, :] += np.sum(
                    At[i_spk, :, :t_imp] * Wt[i_spk, :, :t_imp] *
                    imps[i_spk, :, :t_imp], 0).T

                # Subtract threshold from the accumulator
                acc -= thr * i_spk
                acc[acc < 0] = 0

                # Set new threshold after spike
                thr[i_spk] = -np.log(np.random.rand(n_spk))

                i_spk = acc > thr
                S[t, i_spk] += 1
                n_spk = np.sum(i_spk)

        # DEBUG: compare the sampled spike counts with the integrated rate
        tt = dt * np.arange(nT)
        lam = np.zeros_like(X)
        for n in np.arange(N):
            lam[:, n] = self.glm.nlin_model.f_nlin(X[:, n])

        print("Max firing rate (post sim): %f" % np.max(lam))
        E_nS = np.trapz(lam, tt, axis=0)
        nS = np.sum(S, 0)

        print("Sampled %s spikes." % str(nS))
        print("Expected %s spikes." % str(E_nS))

        if np.any(np.abs(nS - E_nS) > 3 * np.sqrt(E_nS)):
            print("ERROR: Actual num spikes (%s) differs from expected (%s) by >3 std." % (
                str(nS), str(E_nS)))

        print("Number of exceptions arising from multiple spikes per bin: %d" % n_exceptions)

        return S, X
Example #14
0
    def plot(self, xs, ax=None, name='location_provider', color='k'):
        """
        Plot the inferred locations for each neuron.

        With a single sample, each neuron gets one marker: a vertical line
        for 1D locations or a labelled square for 2D locations. With
        several samples, each neuron's locations are shown as a histogram.

        :param xs: a sample or a list of samples; each one is indexed as
                   x['latent'][name]
        :param ax: axes to draw on; defaults to the current pyplot axes
        :param name: key of the location component in the latent variables.
                     Nothing is plotted if xs[0]['latent'] does not have it.
        :param color: color of lines, markers, labels and 1D histograms
        :raises Exception: if the locations are not 1- or 2-dimensional
        """
        # Ensure sample is a list
        if not isinstance(xs, list):
            xs = [xs]

        if name not in xs[0]['latent']:
            return

        # All drawing below goes through ``ax``; use the current axes when
        # the caller did not provide one.
        if ax is None:
            ax = plt.gca()

        # Get the locations
        loccomp = self.population.latent.latentdict[name]
        locprior = loccomp.location_prior
        locvars = loccomp.get_variables()
        Ls = np.array(
            [seval(loccomp.Lmatrix, locvars, x['latent'][name]) for x in xs])
        # Ls is indexed as (sample, neuron, location dimension)
        [N_smpls, N, D] = Ls.shape

        for n in range(N):
            if N_smpls == 1:
                if D == 1:
                    # Vertical line at the neuron's location
                    ax.plot([Ls[0, n, 0], Ls[0, n, 0]], [0, 2],
                            color=color,
                            lw=2)
                elif D == 2:
                    # Dimension 1 goes on the x axis, dimension 0 on the y axis
                    ax.plot(Ls[0, n, 1],
                            Ls[0, n, 0],
                            's',
                            color=color,
                            markerfacecolor=color)
                    ax.text(Ls[0, n, 1] + 0.25,
                            Ls[0, n, 0] + 0.25,
                            '%d' % n,
                            color=color)

                    # Set the limits (y axis is inverted: max first)
                    # NOTE(review): x data is dimension 1 but the x limits
                    # use min0/max0 (and y vice versa) - confirm intended.
                    ax.set_xlim((locprior.min0 - 0.5, locprior.max0 + 0.5))
                    ax.set_ylim((locprior.max1 + 0.5, locprior.min1 - 0.5))
                else:
                    raise Exception("Only plotting locs of dim <= 2")
            else:
                # Plot a histogram of samples
                # NOTE(review): newer matplotlib releases replace ``normed``
                # with ``density``; kept as-is for the version this targets.
                if D == 1:
                    ax.hist(Ls[:, n, 0], bins=20, normed=True, color=color)
                elif D == 2:
                    ax.hist2d(Ls[:, n, 1],
                              Ls[:, n, 0],
                              bins=np.arange(-0.5, 5),
                              cmap='Reds',
                              alpha=0.5,
                              normed=True)

                    # Set the limits (y axis is inverted: max first)
                    ax.set_xlim((locprior.min0 - 0.5, locprior.max0 + 0.5))
                    ax.set_ylim((locprior.max1 + 0.5, locprior.min1 - 0.5))
                else:
                    raise Exception("Only plotting locs of dim <= 2")