Example #1
def calc_stca(spikes, stimulus, filter_length):
    rw = asc.rolling_window(stimulus, filter_length, preserve_dim=True)
    sta = (spikes @ rw) / spikes.sum()
    # The STA is not projected out here, unlike Equation 4 in Schwartz et al. 2006, J. Vision
    precovar = (rw * spikes[:, None]) - sta
    stc = (precovar.T @ precovar) / (spikes.sum() - 1)
    return sta, stc
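For reference, one way to remove the STA direction from the stimulus ensemble before computing the covariance, roughly in the spirit of Equation 4 in Schwartz et al. 2006, is to project each stimulus row orthogonal to the STA. The sketch below is only an illustration with names chosen here (it is not the module's project_component_out_stimulus_matrix used in sigtest further down); it assumes rw, spikes and sta as defined in calc_stca above.

import numpy as np

def project_out_sta_sketch(rw, spikes, sta):
    # Remove the component along the unit-normalized STA from every stimulus row
    sta_unit = sta / np.linalg.norm(sta)
    rw_reduced = rw - np.outer(rw @ sta_unit, sta_unit)
    # Spike-weighted covariance of the STA-free stimulus ensemble
    precovar = rw_reduced * spikes[:, None]
    return (precovar.T @ precovar) / (spikes.sum() - 1)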
Example #2
def calc_stca(spikes, stimulus, filter_length):
    """
    Calculate spike triggered average and spike triggered covariance
    from a 1D stimulus vector and the corresponding spikes.
    """
    rw = asc.rolling_window(stimulus, filter_length, preserve_dim=True)
    sta, stc = calc_stca_from_stimulus_matrix(spikes, rw)
    return sta, stc
Example #3
def packdims(array, window):
    """
    Apply the rolling window to each row of a (possibly multi-dimensional)
    array and stack the resulting windows side by side, yielding a 2D
    stimulus matrix with one row per time bin.
    """
    sh = array.shape
    if array.ndim == 1:
        array = array[None, :]
    if array.ndim > 2:
        array = array.reshape(np.prod(sh[:-1]), -1)
    rw = np.empty((sh[-1], array.shape[0] * window))
    for i in range(array.shape[0]):
        rw[:, i * window:(i + 1) * window] = asc.rolling_window(
            array[i, :], window)
    return rw
Example #4
def grad(kmu):
    k_ = np.array(kmu[:-1])
    mu_ = kmu[-1]
    nlt_in = (conv(k_, x) + mu_)
    xr = asc.rolling_window(x, k_.shape[0])[:, ::-1]
    dldk = spikes @ xr - time_res * np.exp(nlt_in) @ xr
    #        dldk2 = np.zeros(l)
    #        for i in range(len(spikes)):
    #            dldk2 += spikes[i] * xr[i, :]
    #            dldk2 -= time_res*np.exp(nlt_in[i])*xr[i, :]
    #        assert np.isclose(dldk, dldk2).all()
    #        import pdb; pdb.set_trace()
    dldm = spikes.sum() - time_res * np.exp(nlt_in).sum()
    dl = -np.array([*dldk, dldm])
    return dl
Example #5
def conv2d(Q, x, optimize='greedy'):
    """
    Calculate the quadratic form for each time bin of the generalized
    quadratic model.

    Uses
    * rolling window to reduce used memory
    * np.broadcast_to for shaping the quadratic filter matrix in the required
      form without allocating memory
    """
    l = Q.shape[0]
    # Generate a rolling view of the stimulus without allocating space in memory
    # Equivalent to "xr = hankel(x)[:, :l]" but much more memory efficient
    xr = asc.rolling_window(x, l)[:, ::-1]
    # Stack copies of Q along a new axis without copying in memory.
    Qb = np.broadcast_to(Q, (x.shape[0], *Q.shape))
    return np.einsum('ij,ijk,ki->i', xr, Qb, xr.T, optimize=optimize)
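A quick check of the quadratic-form trick above, using numpy's sliding_window_view as a stand-in for asc.rolling_window (an assumption; the exact alignment and padding of asc.rolling_window may differ):

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

l, n = 5, 200
rng = np.random.default_rng(0)
x = rng.standard_normal(n)
Q = rng.standard_normal((l, l))

xr = sliding_window_view(x, l)[:, ::-1]       # one stimulus history per row
Qb = np.broadcast_to(Q, (xr.shape[0], l, l))  # a view, no extra memory
fast = np.einsum('ij,ijk,ki->i', xr, Qb, xr.T)

slow = np.array([row @ Q @ row for row in xr])  # explicit per-bin loop
assert np.allclose(fast, slow)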
Example #6
    checkerstimnr = 6
    maxframes = 20000

    st = OMB(exp, ombstimnr, maxframes=maxframes)
    st.clusterstats()
    #import time; time.sleep(2)
    #st.playstimulus(frames=6, pause_duration=None)

    #%%
    from datetime import datetime
    import miscfuncs as msc
    startime = datetime.now()
    contrast = st.generatecontrast(st.texpars.noiselim / 2, 100, 19)
    contrast_avg = contrast.mean(axis=-1)

    rw = asc.rolling_window(contrast, st.filter_length, preserve_dim=False)

    all_spikes = np.zeros((st.nclusters, st.ntotal))
    for i in range(st.nclusters):
        all_spikes[i, :] = st.binnedspiketimes(i)

    stas = np.einsum('abcd,ec->eabd', rw, all_spikes)
    stas /= all_spikes.sum(axis=(-1))[:, np.newaxis, np.newaxis, np.newaxis]

    # Correct for the non-informative parts of the stimulus
    stas = stas - contrast_avg[None, ..., None]

    print(
        f'{msc.timediff(startime)} elapsed for contrast generation and STA calculation'
    )
    #%%
Example #7
exp, stimnr = '20180710', 1

ff = Stimulus(exp, stimnr)
stimulus = np.array(randpy.gasdev(-1000, ff.frametimings.shape[0])[0])

st = OMB(exp, 8)
ff = st
stimulus = st.bgsteps[0, :]

allspikes = ff.allspikes()

i = 0
spikes = allspikes[i, :]
filter_length = ff.filter_length

rw = asc.rolling_window(stimulus, filter_length, preserve_dim=True)
sta = (spikes @ rw) / spikes.sum()
#%%
# I am not projecting out the STA, unlike Equation 4 in Schwartz et al. 2006, J. Vision
precovar = (rw * spikes[:, None]) - sta
stc = (precovar.T @ precovar) / (spikes.sum() - 1)
eigvals, eigvecs = np.linalg.eig(stc)
eigsort = np.argsort(eigvals)
eigvals, eigvecs = eigvals[eigsort], eigvecs[:, eigsort]

stc2 = calc_stc(spikes, stimulus, filter_length)

fig, axes = plt.subplots(2, 1)
axes[0].plot(eigvals, 'o')
ax1 = axes[1]
ax1.plot(sta)
Example #8
    exp, ombstimnr = '20180710', 8
    checkerstimnr = 6
    maxframes = 20000

    st = OMB(exp, ombstimnr, maxframes=maxframes)
    st.clusterstats()
    #import time; time.sleep(2)
    #st.playstimulus(frames=6, pause_duration=None)

    #%%
    from datetime import datetime
    import miscfuncs as msc
    startime = datetime.now()
    a = st.generatecontrast(st.texpars.noiselim / 2, 100)
    # Capitalize the variable name to prevent it from slowing down the variable explorer
    RW = asc.rolling_window(a, st.filter_length)

    all_spikes = np.zeros((st.nclusters, st.ntotal))
    for i in range(st.nclusters):
        all_spikes[i, :] = st.binnedspiketimes(i)

    # Add a cell that spikes at every bin to find the non-spike triggered average
    all_spikes = np.vstack((all_spikes, np.ones(all_spikes.shape[-1])))
    stas = np.einsum('abcd,ec->eabd', RW, all_spikes)
    del RW
    stas /= all_spikes.sum(axis=(-1))[:, np.newaxis, np.newaxis, np.newaxis]
    print(
        f'{msc.timediff(startime)} elapsed for contrast generation and STA calculation'
    )
    #%%
    #    fig1 = plt.figure(1)
Example #9
def minimize_loglikelihood(
        k_initial,
        #                           Q_initial,
        mu_initial,
        x,
        time_res,
        spikes,
        usegrad=True,
        debug_grad=False,
        method='CG',
        minimize_disp=False,
        **kwargs):
    """
    Calculate the filters that minimize the log likelihood function for a
    given set of spikes and stimulus.

    Parameters
    --------
    k_initial, Q_initial, mu_initial:
        Initial guesses for the parameters.
    x:
        The stimulus
    time_res:
        Length of each time bin (also referred to as Delta or frame_duration)
    spikes:
        Binned spikes, must have the same shape as the stimulus
    usegrad:
        Whether to use analytical gradients for optimization. If set to False,
        approximated gradients will be used with the appropriate optimization
        method.
    debug_grad:
        Whether to calculate and plot the gradients in the first iteration.
        Setting it to True will change the returned values.
    method:
        Optimization method to use; see the Notes section in the documentation
        of scipy.minimize for a full list.
    minimize_disp:
        Whether to print the convergence messages of the optimization function
    """
    kQmu_initial = flattenpars(
        k_initial,
        #                               Q_initial
        mu_initial)

    # Infer the filter length from the shape of the initial guesses and
    # set it globally so that other functions can also use it.
    global filter_length
    if filter_length is None:
        filter_length = k_initial.shape[0]

    def loglikelihood(kQmu):
        """
        Define the likelihood function for GQM
        """
        # Star before an argument expands (or unpacks) the values
        P = gqm_in(*splitpars(kQmu))
        return -np.sum(spikes * P(x)) + time_res * np.sum(np.exp(P(x)))

    # Instead of iterating over each time bin, generate a hankel matrix
    # from the stimulus vector and operate on that using matrix
    # multiplication like so: X @ xh , where X is a vector containing
    # some number for each time bin.

    # xh = hankel(x)[:, :filter_length]
    # Instead of iterating over each time bin, use the rolling window function.
    # The expression in the brackets inverts the array.
    xr = asc.rolling_window(x, filter_length)[:, ::-1]
    # Initialize a 3D numpy array to keep outer products
    sTs = np.zeros((spikes.shape[0], filter_length, filter_length))
    for i in range(spikes.shape[0] - filter_length):
        #        x_temp = x[i:i+filter_length][np.newaxis,:]
        x_temp = xr[i, :]
        sTs[i, :, :] = np.outer(x_temp, x_temp)
    # Empirically found correction terms for the gradients.
#    k_correction = x.shape[0]*time_res*xr.sum(axis=0)
    plt.plot(np.diag(sTs.sum(axis=0)))
    plt.title('diag(sTs.sum(axis=0))')
    plt.show()

    #    import pdb; pdb.set_trace()
    #    q_correction = x.shape[0]*time_res*sTs.sum(axis=0) + np.eye(filter_length)*x.shape[0]
    #    q_correction = x.shape[0]*time_res*sTs.sum(axis=0) + np.diag(sTs.sum(axis=0))
    #    mu_correction = (x.shape[0]-1) * x.shape[0]*time_res
    def gradients(kQmu):
        """
        Calculate gradients for the log-likelihood function
        """
        k, mu = splitpars(kQmu)
        P = np.exp(gqm_in(k, mu)(x))
        #        Slow way of calculating the gradients
        #        dLdk = np.zeros(k.shape)
        #        dLdq = np.zeros(Q.shape)
        #        dLdmu = 0
        #        for i in range(filter_length, x_mini.shape[0]):
        #            s = x[i:i+filter_length]
        #            dLdk += (spikes[i] * s -
        #                       time_res*P[i]*s)
        #            dLdq += (spikes[i] * np.outer(s,s) - time_res*P[i] * np.outer(s, s))
        #            dLdmu += spikes[i] - time_res * P[i]
        # Fast way of calculating gradients using rolling window and einsum
        dLdk = spikes @ xr - time_res * (P @ xr)
        #        dLdk -= k_correction
        # Using einsum to multiply and sum along the desired axis.
        # more detailed explanation here:
        # https://stackoverflow.com/questions/26089893/understanding-numpys-einsum
        dLdq = (np.einsum('ijk,i->jk', sTs, spikes) -
                time_res * np.einsum('ijk,i->jk', sTs, P))
        #        dLdq -= q_correction
        dLdmu = spikes.sum() - time_res * np.sum(P)
        #        dLdmu -= mu_correction
        #        import pdb; pdb.set_trace()

        dL = flattenpars(dLdk, dLdmu)
        return -dL

    if debug_grad:
        # Epsilon value to use when approximating the gradient
        eps = 1e-10
        ap_grad = approx_fprime(kQmu_initial, loglikelihood, eps)
        man_grad = gradients(kQmu_initial)
        # Split the auto and manual gradients into k, q and mu
        kda, mda = splitpars(ap_grad)
        kdm, mdm = splitpars(man_grad)
        diff = ap_grad - man_grad
        k_diff, mu_diff = splitpars(diff)
        print('Gradient diff L2 norm', np.sum(diff**2))
        plt.figure(figsize=(7, 10))
        axk = plt.subplot(411)
        axk.plot(kda, label='Auto grad')
        axk.plot(kdm, label='Manual grad')
        axk.legend()
        axkdif = plt.subplot(412)
        axkdif.plot(k_diff, 'k', label='auto - manual gradient')
        axkdif.legend()
        #        axqa = plt.subplot(425)
        #        imqa = axqa.imshow(qda)
        #        axqa.set_title('Auto grad Q')
        #        plt.colorbar(imqa)
        #        axqm = plt.subplot(426)
        #        axqm.set_title('Manual grad Q')
        #        imqm = axqm.imshow(qdm)
        #        plt.colorbar(imqm)
        #        axqdif = plt.subplot(427)
        #        imqdif = axqdif.imshow(Q_diff)
        #        plt.colorbar(imqdif)
        #        plt.suptitle(f'Difference of numerical and explicit gradients, mu_diff: {mu_diff:11.2f}')
        plt.show()
        #        import pdb; pdb.set_trace();
        return kda, mda, kdm, mdm
    # If debug_grad is True, the function returns on the previous line and the
    # rest of minimize_loglikelihood is not executed.
    minimizekwargs = {'options': {'disp': minimize_disp}}
    if usegrad:
        minimizekwargs.update({'jac': gradients})
    minimizekwargs.update(kwargs)

    res = minimize(loglikelihood,
                   kQmu_initial,
                   tol=1e-5,
                   method=method,
                   **minimizekwargs)
    return res
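A small sanity check of the vectorized gradient step used above: spikes @ xr accumulates spikes[i] * xr[i, :] over all time bins, which is exactly what the commented-out verification loop in the grad() snippet (Example #4) computes. sliding_window_view is used here as an assumed stand-in for asc.rolling_window.

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

l = 4
rng = np.random.default_rng(1)
x = rng.standard_normal(50)
xr = sliding_window_view(x, l)[:, ::-1]
spikes = rng.poisson(0.5, size=xr.shape[0]).astype(float)

vectorized = spikes @ xr
looped = np.zeros(l)
for i in range(xr.shape[0]):
    looped += spikes[i] * xr[i, :]
assert np.allclose(vectorized, looped)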
Example #10
checkerstimnr = 1


st = OMB(exp, ombstimnr,
         maxframes=1000
         )

choosecells = [54, 55, 108, 109]
nrcells = len(choosecells)

all_spikes = np.zeros((nrcells, st.ntotal), dtype=np.int8)

for i, cell in enumerate(choosecells):
    all_spikes[i, :] = st.binnedspiketimes(cell)

rw = asc.rolling_window(st.bgsteps, st.filter_length)

motionstas = np.einsum('abc,db->dac', rw, all_spikes)
motionstas /= all_spikes.sum(axis=(-1))[:, np.newaxis, np.newaxis]

#%% Filter the stimuli

# Euclidean norm
motionstas_norm = motionstas / np.sqrt((motionstas**2).sum(axis=-1))[:, :, None]

bgsteps = st.bgsteps / np.sqrt(st.bgsteps.var())
rw = asc.rolling_window(bgsteps, st.filter_length)


steps_proj = np.einsum('abc,bdc->ad', motionstas_norm, rw)
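A shape sketch for the projection above, with toy arrays standing in for the motion STAs and the rolling-window background steps (assumed shapes: filters (ncells, 2, filter_length), stimulus windows (2, ntime, filter_length)); the einsum is a filter-times-history dot product per cell and time bin, summed over the x and y components:

import numpy as np

ncells, ntime, flen = 3, 40, 6
rng = np.random.default_rng(2)
filters = rng.standard_normal((ncells, 2, flen))
rw = rng.standard_normal((2, ntime, flen))

proj = np.einsum('abc,bdc->ad', filters, rw)

# Equivalent explicit loop over cells and time bins
proj_loop = np.zeros((ncells, ntime))
for a in range(ncells):
    for d in range(ntime):
        proj_loop[a, d] = np.sum(filters[a] * rw[:, d, :])
assert np.allclose(proj, proj_loop)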
Example #11
def ombtexturesta(exp, ombstimnr, maxframes=10000,
                  contrast_window=100, plot=False):
    """
    Calculates the spike-triggered average for the full texture for the OMB
    stimulus. Based on the maximum intensity pixel of the STAs, calculates
    the center of the receptive field and the contrast signal for this
    pixel throughout the stimulus; to be used as input for models.

    Parameters
    ----------
        exp:
            The experiment name
        ombstimnr:
            Number of the OMB stimulus in the experiment
        maxframes:
            Maximum number of frames that will be used; the array containing
            the contrast is typically very large and it is easy to fill the
            RAM. Refer to the OMB.generatecontrast() documentation.
        contrast_window:
            Number of pixels to be used for the size of the texture,
            measured in each direction starting from the center, so a value
            of 100 will yield a texture of size (201, 201, N), where N is
            the total number of frames.
        plot:
            If True, draws an interactive plot for browsing all STAs, also
            marking the center pixels. Requires an interactive backend.

    """
    st = OMB(exp, ombstimnr, maxframes=maxframes)
    st.clusterstats()

    contrast = st.generatecontrast(st.texpars.noiselim/2,
                                   window=contrast_window,
                                   pad_length=st.filter_length-1)

    contrast_avg = contrast.mean(axis=-1)

    RW = asc.rolling_window(contrast, st.filter_length, preserve_dim=False)

    all_spikes = np.zeros((st.nclusters, st.ntotal))
    for i in range(st.nclusters):
        all_spikes[i, :] = st.binnedspiketimes(i)

    texturestas = np.einsum('abcd,ec->eabd', RW, all_spikes)

    texturestas /= all_spikes.sum(axis=(-1))[:, np.newaxis,
                                             np.newaxis, np.newaxis]

    # Correct for the non-informative parts of the stimulus
    texturestas = texturestas - contrast_avg[None, ..., None]
    #%%
    if plot:
        fig_stas, _ = plf.multistabrowser(texturestas, cmap='Greys_r')

    texture_maxi = np.zeros((st.nclusters, 2), dtype=int)
    # Take the pixel with maximum intensity for contrast signal
    for i in range(st.nclusters):
        coords = np.unravel_index(np.argmax(np.abs(texturestas[i])),
                                  texturestas[i].shape)[:-1]
        texture_maxi[i, :] = coords
        if plot:
            ax = fig_stas.axes[i]
            # Coordinates need to be inverted for display
            ax.plot(*coords[::-1], 'r+', markersize=10, alpha=0.2)
    #%%
    contrast_signals = np.empty((st.nclusters, st.ntotal))
    # Calculate the time course of the center (maximal) pixel of the texture STAs
    stas_center = np.zeros((st.nclusters, st.filter_length))
    for i in range(st.nclusters):
        coords = texture_maxi[i, :]
        # Calculate the contrast signal that can be used for GQM
        # Cut the extra part at the beginning that was added by generatecontrast
        contrast_signals[i, :] = contrast[coords[0], coords[1],
                                          st.filter_length-1:]
        stas_center[i] = texturestas[i, coords[0], coords[1], :]

    stas_center_norm = asc.normalize(stas_center)

    fig_contrast, axes = plt.subplots(*plf.numsubplots(st.nclusters), sharey=True)
    for i, ax in enumerate(axes.ravel()):
        if i < st.nclusters:
            ax.plot(stas_center_norm[i, :])

    savepath = os.path.join(st.exp_dir, 'data_analysis', st.stimname)
    savefname = f'{st.stimnr}_texturesta'
    if not maxframes:
        maxframes = st.ntotal
    savefname += f'_{maxframes}fr'

    plt.ylim([np.nanmin(stas_center_norm), np.nanmax(stas_center_norm)])
    fig_contrast.suptitle('Time course of center pixel of texture STAs')
    fig_contrast.savefig(os.path.join(savepath, 'texturestas.svg'))

    # Do not save the contrast signal because it is ~6GB for 20000 frames of recording
    keystosave = ['texturestas', 'contrast_avg', 'stas_center',
                  'stas_center_norm', 'contrast_signals', 'texture_maxi',
                  'maxframes', 'contrast_window']
    datadict = {}
    for key in keystosave:
        datadict[key] = locals()[key]

    np.savez(os.path.join(savepath, savefname), **datadict)
    if plot:
        return fig_stas
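A shape sketch for the texture-STA einsum used above: with a rolling-window contrast of shape (ny, nx, ntime, filter_length) and binned spikes of shape (ncells, ntime), 'abcd,ec->eabd' sums the spike-weighted stimulus history over time and yields one (ny, nx, filter_length) STA per cell (toy shapes assumed here):

import numpy as np

ny, nx, ntime, flen, ncells = 4, 4, 30, 5, 2
rng = np.random.default_rng(3)
rw = rng.standard_normal((ny, nx, ntime, flen))
spikes = rng.poisson(1.0, size=(ncells, ntime)).astype(float)

stas = np.einsum('abcd,ec->eabd', rw, spikes)
stas /= spikes.sum(axis=-1)[:, None, None, None]
assert stas.shape == (ncells, ny, nx, flen)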
Example #12
def minimize_loglikelihood(k_initial,
                           Q_initial,
                           mu_initial,
                           x,
                           time_res,
                           spikes,
                           usegrad=True,
                           method='CG',
                           minimize_disp=False,
                           **kwargs):
    """
    Calculate the filters that minimize the log likelihood function for a
    given set of spikes and stimulus.

    Parameters
    --------
    k_initial, Q_initial, mu_initial:
        Initial guesses for the parameters.
    x:
        The stimulus. The last axis should be temporal, and the number of
        stimulus dimensions should match the initial guesses for the parameters.
    time_res:
        Length of each time bin (also referred to as Delta or frame_duration)
    spikes:
        Binned spikes, must have the same shape as the stimulus
    usegrad:
        Whether to use analytical gradients for optimization. If set to False,
        approximated gradients will be used with the appropriate optimization
        method.
    method:
        Optimization method to use; see the Notes section in the documentation
        of scipy.minimize for a full list.
    minimize_disp:
        Whether to print the convergence messages of the optimization function
    """
    kQmu_initial = flattenpars(k_initial, Q_initial, mu_initial)

    # Infer the filter length from the shape of the initial guesses and
    # set it globally so that other functions can also use it.
    global filter_length, stimdim
    if filter_length is None:
        filter_length = k_initial.shape[-1]
    if stimdim is None:
        if x.ndim > 1:
            stimdim = x.shape[0]
        else:
            stimdim = 1

    global sTs, xr  # So that they are reachable from gradients function
    # Initialize an N-D numpy array to keep outer products
    sTs = np.zeros((stimdim, spikes.shape[0], filter_length, filter_length))
    # Instead of iterating over each time bin, use the rolling window function
    # The expression in the brackets inverts the array.
    xr = asc.rolling_window(x, filter_length)[..., ::-1]
    # Add one extra dimension at the beginning in case the stimulus is
    # single dimensional
    xr = xr[None, ...] if x.ndim == 1 else xr
    for j in range(stimdim):
        for i in range(spikes.shape[0] - filter_length):
            x_temp = xr[j, i, :]
            sTs[j, i, :, :] = np.outer(x_temp, x_temp)

    minimizekwargs = {'options': {'disp': minimize_disp}}
    if usegrad:
        minimizekwargs.update({'jac': gradients})
    minimizekwargs.update(kwargs)

    res = minimize(loglikelihood,
                   kQmu_initial,
                   tol=1e-5,
                   method=method,
                   args=(x, spikes, time_res),
                   **minimizekwargs)
    return res
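The nested loop above fills sTs[j, i] with the outer product of the stimulus history xr[j, i] with itself; the same tensor of outer products can be written as a single einsum. A sketch with toy shapes (not the module's code, and without the trailing rows that the loop above leaves at zero):

import numpy as np

stimdim, ntime, flen = 2, 30, 5
rng = np.random.default_rng(4)
xr = rng.standard_normal((stimdim, ntime, flen))

sTs_loop = np.zeros((stimdim, ntime, flen, flen))
for j in range(stimdim):
    for i in range(ntime):
        sTs_loop[j, i] = np.outer(xr[j, i], xr[j, i])

sTs_einsum = np.einsum('jil,jim->jilm', xr, xr)
assert np.allclose(sTs_loop, sTs_einsum)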
Example #13
def sigtest(spikes,
            stimulus,
            filter_length,
            ntest=500,
            confidence_level=.95,
            debug=False):
    """
    Calculate the significant components of the spike-triggered covariance
    matrix.

    Parameters
    ----------
    spikes:
        Binned spikes, one entry per time bin of the stimulus
    stimulus:
        The 1D stimulus vector
    filter_length:
        Length of the temporal window used for the stimulus matrix
    ntest:
        Number of randomly shifted spike trains used to build the null
        distribution of eigenvalues
    confidence_level:
        Confidence level of the bootstrap interval on the extreme eigenvalues
        of the shuffled covariance matrices
    debug:
        If True, print and plot diagnostic information for each iteration

    Returns
    -------
    significant_components:
        indices of the significant components of the spike-triggered covariance
        matrix, matching the return value of np.linalg.eigh(stc).
        Corresponds to the nth element of the eigenvalues and nth column
        of the eigenvectors.

    Example
    -------
    >>> sta, stc = calc_stca(spikes, stimulus, 20)
    >>> eigenvalues, eigenvectors = np.linalg.eigh(stc)
    >>> sig_comp_inds = sigtest(spikes, stimulus, 20)
    >>> print(sig_comp_inds)
    [0, 19]
    >>> significant_components = eigenvectors[:, sig_comp_inds].T

    """
    sta_init, stc_init = calc_stca(spikes, stimulus, filter_length)
    eigvals_init, eigvecs_init = np.linalg.eigh(stc_init)
    no_sig_comp_left = False
    significant_components = np.array([], dtype=int)
    # Keep track of components above and below the mean value
    # of the eigenvalues to calculate the indices correctly
    ncomps_above, ncomps_below = 0, 0
    while not no_sig_comp_left:
        all_v = np.zeros((2, ntest))  # first axis: min, max eigvals

        toremove = eigvecs_init[:, significant_components].T
        rw = asc.rolling_window(stimulus, filter_length, preserve_dim=True)
        reduced_stim_matrix = project_component_out_stimulus_matrix(
            rw, toremove)
        _, stc_loop = calc_stca_from_stimulus_matrix(spikes,
                                                     reduced_stim_matrix)

        eigvals_loop, _ = np.linalg.eigh(stc_loop)
        eigvals_notzero = eigvals_loop[~np.isclose(eigvals_loop, 0, atol=1e-2)]

        for i in range(ntest):
            shifted_spikes = np.roll(spikes,
                                     np.random.randint(spikes.shape[0]))
            _, r_stc = calc_stca_from_stimulus_matrix(shifted_spikes,
                                                      reduced_stim_matrix)

            rand_v, _ = np.linalg.eigh(r_stc)
            # Exclude the zero eigenvalues corresponding to the significant components
            rand_v = np.ma.masked_values(rand_v, value=0, atol=1e-2)
            all_v[:, i] = np.array([rand_v.min(), rand_v.max()])

        low_min, high_min = confidence_interval_bootstrap(
            all_v[0], confidence_level)
        low_max, high_max = confidence_interval_bootstrap(
            all_v[1], confidence_level)

        outliers_low = np.where(eigvals_notzero < low_min)[0]
        outliers_high = np.where(eigvals_notzero > high_max)[0]

        outlier_inds = np.hstack((outliers_low, outliers_high))
        outliers = eigvals_notzero[outlier_inds]
        if len(outliers_low) + len(outliers_high) == 0:
            no_sig_comp_left = True
        else:
            dist_low_to_min = low_min - eigvals_notzero[outliers_low]
            dist_high_to_max = eigvals_notzero[outliers_high] - high_max
            dist_to_extrema = np.hstack((dist_low_to_min, dist_high_to_max))
            largest_outlier_ind = np.argmax(dist_to_extrema)

            largest_outlier = outliers[largest_outlier_ind]
            outlier_index = np.where(eigvals_notzero == largest_outlier)[0]

            # Each time a new component from below the line is added, the returned
            # index decreases by one. We correct for this.
            outlier_index = outlier_index + ncomps_below
            significant_components = np.hstack(
                (significant_components, outlier_index))

            if debug:
                print(f'Eigval {outlier_index} is a significant component')
                plt.figure()
                plt.plot(eigvals_notzero, 'ko')
                for line in [low_min, high_min, low_max, high_max]:
                    plt.axhline(line, color='red', alpha=.3)

            if outlier_index == len(eigvals_notzero) - 1 + ncomps_below:
                # Outlier larger than mean eigenvalue
                ncomps_above += 1
            elif outlier_index == 0 + ncomps_below:
                # Outlier smaller than mean eigenvalue
                ncomps_below += 1
            else:
                raise ValueError('Largest outlier found in unexpected place!')

        if len(significant_components) > 8:
            raise ValueError(
                'Number of significant components is too damn high!')
    return significant_components
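confidence_interval_bootstrap is not shown in this listing; a plausible minimal stand-in (an assumption, not necessarily the author's implementation) is a percentile bootstrap over resampled means of the shuffled extreme eigenvalues:

import numpy as np

def confidence_interval_bootstrap(values, confidence_level=.95, nboot=1000):
    # Percentile bootstrap: resample with replacement, take the mean, and
    # return the lower and upper quantiles of the bootstrap distribution.
    rng = np.random.default_rng()
    boot_means = np.array([
        rng.choice(values, size=values.shape[0], replace=True).mean()
        for _ in range(nboot)
    ])
    alpha = (1 - confidence_level) / 2
    low, high = np.quantile(boot_means, [alpha, 1 - alpha])
    return low, high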
Example #14
def minimize_loglikelihood(k_initial,
                           Q_initial,
                           mu_initial,
                           x,
                           time_res,
                           spikes,
                           usegrad=True,
                           method='CG',
                           minimize_disp=False,
                           **kwargs):
    """
    Calculate the filters that minimize the log likelihood function for a
    given set of spikes and stimulus.

    Parameters
    --------
    k_initial, Q_initial, mu_initial:
        Initial guesses for the parameters.
    x:
        The stimulus
    time_res:
        Length of each time bin (also referred to as Delta or frame_duration)
    spikes:
        Binned spikes, must have the same shape as the stimulus
    usegrad:
        Whether to use analytical gradients for optimization. If set to False,
        approximated gradients will be used with the appropriate optimization
        method.
    method:
        Optimization method to use; see the Notes section in the documentation
        of scipy.minimize for a full list.
    minimize_disp:
        Whether to print the convergence messages of the optimization function
    """
    kQmu_initial = flattenpars(k_initial, Q_initial, mu_initial)

    # Infer the filter length from the shape of the initial guesses and
    # set it globally so that other functions can also use it.
    global filter_length
    if filter_length is None:
        filter_length = k_initial.shape[0]

    def loglikelihood(kQmu):
        """
        Define the likelihood function for GQM
        """
        # Star before an argument expands (or unpacks) the values
        P = gqm_in(*splitpars(kQmu))
        return -np.sum(spikes * P(x)) + time_res * np.sum(np.exp(P(x)))

    # Instead of iterating over each time bin, use the rolling window function
    # The expression in the brackets inverts the array.
    xr = asc.rolling_window(x, filter_length)[:, ::-1]
    # Initialize a 3D numpy array to keep outer products
    sTs = np.zeros((spikes.shape[0], filter_length, filter_length))
    for i in range(spikes.shape[0] - filter_length):
        x_temp = xr[i, :]
        sTs[i, :, :] = np.outer(x_temp, x_temp)

    def gradients(kQmu):
        """
        Calculate gradients for the log-likelihood function
        """
        k, Q, mu = splitpars(kQmu)
        P = np.exp(gqm_in(k, Q, mu)(x))
        # Fast way of calculating gradients using rolling window and einsum
        dLdk = spikes @ xr - time_res * (P @ xr)
        # Using einsum to multiply and sum along the desired axis.
        # more detailed explanation here:
        # https://stackoverflow.com/questions/26089893/understanding-numpys-einsum
        dLdq = (np.einsum('ijk,i->jk', sTs, spikes) -
                time_res * np.einsum('ijk,i->jk', sTs, P))
        dLdmu = spikes.sum() - time_res * np.sum(P)

        dL = flattenpars(dLdk, dLdq, dLdmu)
        return -dL

    minimizekwargs = {'options': {'disp': minimize_disp}}
    if usegrad:
        minimizekwargs.update({'jac': gradients})
    minimizekwargs.update(kwargs)

    res = minimize(loglikelihood,
                   kQmu_initial,
                   tol=1e-5,
                   method=method,
                   **minimizekwargs)
    return res
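Written out, the quantities implemented by loglikelihood() and gradients() above are the Poisson log-likelihood of the GQM and its derivatives. In the notation used here (this note's, not the module's), y_t are the binned spikes, s_t is the stimulus history of length filter_length ending at bin t, \Delta is time_res, and r_t = k^\top s_t + s_t^\top Q s_t + \mu is the generator signal gqm_in(k, Q, mu):

\mathcal{L}(k, Q, \mu) = \sum_t y_t \, r_t - \Delta \sum_t e^{r_t}

\frac{\partial \mathcal{L}}{\partial k} = \sum_t \left( y_t - \Delta e^{r_t} \right) s_t, \qquad
\frac{\partial \mathcal{L}}{\partial Q} = \sum_t \left( y_t - \Delta e^{r_t} \right) s_t s_t^\top, \qquad
\frac{\partial \mathcal{L}}{\partial \mu} = \sum_t \left( y_t - \Delta e^{r_t} \right)

The code minimizes the negative of \mathcal{L}, which is why loglikelihood() and gradients() both return sign-flipped values.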
Example #15
nrcells = len(choosecells)

ombcoords = np.zeros((nrcells, 2))
all_spikes = np.zeros((nrcells, st.ntotal), dtype=np.int8)
#all_contrasts = np.zeros((nrcells, st.ntotal))

for i, cell in enumerate(choosecells):
    ombcoords[i, :] = moc.chkmax2ombcoord(cell, exp, ombstimnr, checkerstimnr)
    all_spikes[i, :] = st.binnedspiketimes(cell)
#    all_contrasts[i, :] = st.generatecontrast(ombcoords[i, :])

contrast = st.generatecontrast(st.texpars.noiselim/2, 100).astype(np.float32)
contrast_sum = asc.normalize(contrast.sum(axis=2), axis_inv=None)
plt.imshow(contrast_sum, cmap='Greys_r')
#%%
rw = asc.rolling_window(contrast, 20)
#rws = rw.transpose((2, 0, 1, 3))

stas = np.einsum('abcd,ec->eabd', rw, all_spikes)
stas = stas / all_spikes.sum(axis=1)[:, None, None, None]
stas_normalized = asc.normalize(stas, axis_inv=0)
plt.imshow(stas[0, ..., 0], cmap='Greys_r')
#%%
from scratch_spikeshuffler import shufflebyrow

shuffled_spikes = shufflebyrow(all_spikes)

shuffled_stas = np.einsum('abcd,ec->eabd', rw, shuffled_spikes)
shuffled_stas = shuffled_stas / all_spikes.sum(axis=1)[:, None, None, None]
shuffled_stas_normalized = asc.normalize(shuffled_stas, axis_inv=0)
plt.imshow(shuffled_stas[0, ..., 0], cmap='Greys_r')
Example #16
def minimize_loglikelihood(k_initial,
                           Q_initial,
                           x,
                           time_res,
                           spikes,
                           usegrad=True,
                           method='CG',
                           minimize_disp=False,
                           **kwargs):
    kQ_initial = flattenpars(k_initial, Q_initial)

    # Infer the filter length from the shape of the initial guesses and
    # set it globally so that other functions can also use it.
    global filter_length
    if filter_length is None:
        filter_length = k_initial.shape[0]
    # Trim away the first filter_length elements to align spikes array
    # with the output of the convolution operations
    if spikes.shape[0] == x.shape[0]:
        #        print('spikes array reshaped while fitting GQM likelihood')
        #        spikes = spikes[filter_length-1:]
        pass

    def loglikelihood(kQ):
        P = gqm_in(*splitpars(kQ))
        return -(np.sum(spikes * P(x) - time_res * np.sum(np.exp(P(x)))))

    # Instead of iterating over each time bin, generate a hankel matrix
    # from the stimulus vector and operate on that using matrix
    # multiplication like so: X @ xh , where X is a vector containing
    # some number for each time bin.

    # xh = hankel(x)[:, :filter_length]
    xr = asc.rolling_window(x, filter_length)[:, ::-1]
    sTs = np.zeros((spikes.shape[0], filter_length, filter_length))
    for i in range(spikes.shape[0] - filter_length):
        #        x_temp = x[i:i+filter_length][np.newaxis,:]
        x_temp = xr[i, :]
        sTs[i, :, :] = np.outer(x_temp, x_temp)
    # x.shape[0] * time_res is the stimulus length in seconds; this correction term was found empirically.
    k_correction = x.shape[0] * time_res * xr.sum(axis=0)
    plt.plot(np.diag(sTs.sum(axis=0)))
    plt.title('diag(sTs.sum(axis=0))')
    plt.show()
    #    import pdb; pdb.set_trace()
    q_correction = x.shape[0] * time_res * sTs.sum(
        axis=0) + np.eye(filter_length) * x.shape[0]

    #    q_correction = x.shape[0]*time_res*sTs.sum(axis=0) + np.diag(sTs.sum(axis=0))
    def gradients(kQ):
        k, Q = splitpars(kQ)
        P = np.exp(gqm_in(k, Q)(x))
        #        dLdk = np.zeros(k.shape)
        #        dLdq = np.zeros(Q.shape)
        #        dLdmu = 0
        #        for i in range(filter_length, x_mini.shape[0]):
        #            s = x[i:i+filter_length]
        #            dLdk += (spikes[i] * s -
        #                       time_res*P[i]*s)
        #            dLdq += (spikes[i] * np.outer(s,s) - time_res*P[i] * np.outer(s, s))
        #            dLdmu += spikes[i] - time_res * P[i]
        dLdk = spikes @ xr - time_res * (P @ xr)
        dLdk -= k_correction
        # Using einsum to multiply and sum along the desired axis.
        # more detailed explanation here:
        # https://stackoverflow.com/questions/26089893/understanding-numpys-einsum
        dLdq = (np.einsum('ijk,i->jk', sTs, spikes) -
                time_res * np.einsum('ijk,i->jk', sTs, P))
        dLdq -= q_correction
        #        import pdb; pdb.set_trace()

        dL = flattenpars(dLdk, dLdq)
        return -dL

    minimizekwargs = {'options': {'disp': minimize_disp}}
    if usegrad:
        minimizekwargs.update({'jac': gradients})
    minimizekwargs.update(kwargs)

    res = minimize(loglikelihood,
                   kQ_initial,
                   tol=1e-1,
                   method=method,
                   **minimizekwargs)
    return res