Exemple #1
0
        axv.set_title('Eigenvalues of Q')
        axw.set_title('Eigenvectors of Q')
        axw.set_xlabel('Time before spike [ms]')

        axv.plot(w_out, 'ko')
        eiginds = [0, 1, fl-2, fl-1]

        for ind, (eigind, w) in enumerate(zip(eiginds, w_out[eiginds])):
            axv.plot(eigind, w, 'o', color=colors[ind])
            axw.plot(t, v_out[:, eigind], color=colors[ind])

            generator = np.convolve(eigvecs[i, j, :, eigind], stimulus[j, :],
                                    mode='full')[:-st.filter_length+1]

            nonlinearity, bins = nlt.calc_nonlin(spikes, generator, nr_bins=40)
            axn.plot(bins, nonlinearity/st.frame_duration, color=colors[ind])

        axn.set_title('Nonlinearities')
        axn.set_ylabel('Firing rate [Hz]')
        axn.set_xlabel('Stimulus projection')

    plt.suptitle(f'{st.exp_foldername}\n'
                 f'{st.stimname}\n'
                 f'{st.clids[i]} {gqmlabel} corr: {cross_corr:4.2f} nsp: {spikes.sum():5.0f}')
    plt.tight_layout()
    plt.subplots_adjust(top=.85)
    #%%
    plt.savefig(os.path.join(savedir, st.clids[i]+'.svg'),
                bbox_inches='tight',
                pad_inches=0.3)
Exemple #2
0
def cca_omb_components(exp: str,
                       stim_nr: int,
                       n_components: int = 6,
                       regularization=None,
                       filter_length=None,
                       cca_solver: str = 'macke',
                       maxframes=None,
                       shufflespikes: bool = False,
                       exclude_allzero_spike_rows: bool = True,
                       savedir: str = None,
                       savefig: bool = True,
                       sort_by_nspikes: bool = True,
                       select_cells: list = None,
                       plot_first_ncells: int = None,
                       whiten: bool = False):
    """
    Analyze OMB responses using canonical correlation analysis and plot the results.

    Four figures are produced (cell weights in default order, components
    together with cell weights, canonical correlations, nonlinearities) and
    the main results are written to an ``.npz`` file in ``savedir``.
    Nothing is returned.

    Parameters
    ---
    exp:
        Experiment identifier, passed on to ``OMB``.
    stim_nr:
        Stimulus number, passed on to ``OMB``.
    n_components:
        Number of components that will be requested from the CCA analysis. More numbers mean
        the algorithm will stop at a later point. That means components of analyses with fewer
        n_components are going to be identical to the first n components of the higher-number
        component analyses.
    regularization:
        The regularization parameter to be passed onto rcca.CCA. Not relevant for macke.
        None is treated as 0.
    filter_length:
        The length of the time window to be considered in the past for the stimulus and the responses.
        Can be different for stimulus and response, if a tuple (stimulus, response) is given.
        If None, the filter length of the stimulus object is used for both.
    cca_solver:
        Which CCA solver to use. Options are `rcca` and `macke`(default)
    maxframes: int
        Number of frames to load in the experiment object. Used to avoid memory and performance
        issues.
    shufflespikes: bool
        Whether to randomize the spikes, to validate the results
    exclude_allzero_spike_rows:
        Exclude all cells which have zero spikes for the duration of the stimulus.
    savedir: str
        Custom directory to save the figures and data files. If None, will be saved in the experiment
        directory under appropriate path. The directory is created if needed.
    savefig: bool
        Whether to save the figures. The data file is written regardless.
    sort_by_nspikes: bool
        Whether to sort the cell weights array by the number of spikes during the stimulus.
    select_cells: list
       A list of indexes for the subset of cells to perform the analysis for.
    plot_first_ncells: int
        Number of cells to plot in the cell plots.
    whiten: bool
        Forwarded to the rcca solver; the macke solver call does not receive it.

    Raises
    ---
    ValueError
        If ``cca_solver`` is neither `rcca` nor `macke`.
    """
    # Fail early instead of hitting a NameError after the data was loaded.
    if cca_solver not in ('rcca', 'macke'):
        raise ValueError(
            f"Unknown cca_solver {cca_solver!r}; expected 'rcca' or 'macke'.")

    if regularization is None:
        regularization = 0

    st = OMB(exp, stim_nr, maxframes=maxframes)

    if filter_length is None:
        filter_length = st.filter_length

    # Accept numpy integers as well; the same window is then used for
    # stimulus and response.
    if isinstance(filter_length, (int, np.integer)):
        filter_length = (int(filter_length), int(filter_length))

    if savedir is None:
        savedir = st.stim_dir / 'CCA'
    else:
        savedir = Path(savedir)
    # Create the target directory for custom paths too, otherwise saving fails.
    savedir.mkdir(exist_ok=True, parents=True)

    spikes = st.allspikes()
    # Count the spikes BEFORE removing the mean: afterwards every row sums to
    # (numerically) zero and the counts would be meaningless.
    nspikes_percell = spikes.sum(axis=1)
    nonzerospikerows = ~np.isclose(nspikes_percell, 0)
    # Set the mean to zero for spikes
    spikes -= spikes.mean(axis=1)[:, None]

    if select_cells is not None:
        select_cells = np.asarray(select_cells)
        spikes = spikes[select_cells]
        # Keep the per-cell bookkeeping aligned with the selected rows.
        nspikes_percell = nspikes_percell[select_cells]
        nonzerospikerows = nonzerospikerows[select_cells]
        st.nclusters = len(select_cells)
        # Convert to a list of plain Python numbers for better string
        # representation; np.array is printed as "array([....])" with newline
        # characters which is problematic in filenames.
        select_cells = select_cells.tolist()

    # Exclude rows that have no spikes throughout
    if exclude_allzero_spike_rows:
        spikes = spikes[nonzerospikerows, :]
        nspikes_percell = nspikes_percell[nonzerospikerows]

    if shufflespikes:
        spikes = spikeshuffler.shufflebyrow(spikes)

    figsavename = f'{n_components=}_{shufflespikes=}_{select_cells=}_{regularization=}_{filter_length=}_{whiten=}'
    # If the file length gets too long due to the list of selected cells, summarize it.
    if len(figsavename) > 200 and select_cells:
        figsavename = f'{n_components=}_{shufflespikes=}_select_cells={len(select_cells)}cells-index{select_cells[0]}to{select_cells[-1]}_{regularization=}_{filter_length=}_{whiten=}'

    stimulus = mft.packdims(st.bgsteps, filter_length[0])
    spikes = mft.packdims(spikes, filter_length[1])

    if cca_solver == 'rcca':
        resp_comps, stim_comps, cancorrs = cca_rcca(spikes, stimulus,
                                                    filter_length,
                                                    n_components,
                                                    regularization, whiten)
    else:  # 'macke', validated at the top
        resp_comps, stim_comps, cancorrs = cca_macke(spikes, stimulus,
                                                     filter_length,
                                                     n_components)

    nsp_argsorted = np.argsort(nspikes_percell)
    resp_comps_sorted_nsp = resp_comps[:, nsp_argsorted, :]

    if sort_by_nspikes:
        resp_comps_toplot = resp_comps_sorted_nsp
    else:
        resp_comps_toplot = resp_comps

    if plot_first_ncells is not None:
        resp_comps_toplot = resp_comps_toplot[:, :plot_first_ncells, ...]

    motionfilt_r, motionfilt_theta = mft.cart2pol(stim_comps[:, 0, :],
                                                  stim_comps[:, 1, :])

    # Shared symmetric color limits for all cell weight images.
    resp_vmin = asc.absmin(resp_comps)
    resp_vmax = asc.absmax(resp_comps)

    # ---- Figure 1: cell weights in default order ----
    nrows, ncols = plf.numsubplots(n_components)
    # squeeze=False keeps the axes array 2D for single-row/column grids.
    fig_cells, axes_cells = plt.subplots(nrows, ncols, figsize=(10, 10),
                                         squeeze=False)

    for i in range(n_components):
        ax = axes_cells.flat[i]
        ax.imshow(resp_comps[i, :],
                  cmap='RdBu_r',
                  vmin=resp_vmin,
                  vmax=resp_vmax,
                  aspect='auto',
                  interpolation='nearest')
        ax.set_title(f'{i}')
    fig_cells.suptitle(f'Cells default order {shufflespikes=}')
    if savefig:
        fig_cells.savefig(savedir / f'{figsavename}_cells_default_order.pdf')
    plt.close(fig_cells)

    # ---- Figure 2: motion filters, magnitude, direction and cell weights ----
    height_list = [1, 1, 1, 3]  # ratios of the plots in each component

    # Create time vectors (in ms, negative = before the spike) for the plots
    t_stim = -np.arange(0, filter_length[0] * st.frame_duration,
                        st.frame_duration)[::-1] * 1000
    t_response = -np.arange(0, filter_length[1] * st.frame_duration,
                            st.frame_duration)[::-1] * 1000
    xtick_loc_params = dict(nbins=4, steps=[2, 5, 10], integer=True)

    nsubrows = len(height_list)
    # Repeat the ratio pattern once per row of components.
    height_ratios = height_list * int(nrows)
    fig, axes = plt.subplots(nrows=nrows * nsubrows,
                             ncols=ncols,
                             gridspec_kw={'height_ratios': height_ratios},
                             figsize=(11, 10),
                             squeeze=False)

    sorted_tag = "(sorted nsp)" * sort_by_nspikes
    first_tag = (f"(first {plot_first_ncells} cells)"
                 * (type(plot_first_ncells) is int))
    cells_ylabel = f'Cells\n{sorted_tag}\n{first_tag}'

    # NOTE(review): mode_i is used to index the component arrays directly;
    # this presumes the grid has no more cells than available components --
    # confirm against plf.numsubplots.
    for row, ax_row in enumerate(axes):
        subrow = row % nsubrows
        for col, ax in enumerate(ax_row):
            mode_i = (row // nsubrows) * ncols + col
            ax.set_yticks([])
            # Plot motion filters
            if subrow == 0:
                ax.plot(t_stim,
                        stim_comps[mode_i, 0, :],
                        marker='o',
                        markersize=1)
                ax.plot(t_stim,
                        stim_comps[mode_i, 1, :],
                        marker='o',
                        markersize=1)
                if col == 0:
                    ax.set_ylabel('Motion',
                                  rotation=0,
                                  ha='right',
                                  va='center')
                ax.set_ylim(stim_comps.min(), stim_comps.max())

                # Draw a horizontal line for zero and prevent rescaling of x-axis
                xlims = ax.get_xlim()
                ax.hlines(0,
                          *xlims,
                          colors='k',
                          linestyles='dashed',
                          alpha=0.3)
                ax.set_xlim(*xlims)

                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))

                # Only the first component keeps its tick labels, and only
                # when stimulus and response windows differ.
                if mode_i != 0 or filter_length[0] == filter_length[1]:
                    ax.xaxis.set_ticklabels([])
                else:
                    ax.tick_params(axis='x', labelsize=8)

            # Plot magnitude of motion
            elif subrow == 1:
                ax.plot(t_stim,
                        motionfilt_r[mode_i, :],
                        color='k',
                        marker='o',
                        markersize=1)
                if col == 0:
                    ax.set_ylabel('Magnitude',
                                  rotation=0,
                                  ha='right',
                                  va='center')
                ax.set_ylim(motionfilt_r.min(), motionfilt_r.max())
                ax.xaxis.set_ticklabels([])
                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))
            # Plot direction of motion
            elif subrow == 2:
                ax.plot(t_stim,
                        motionfilt_theta[mode_i, :],
                        color='r',
                        marker='o',
                        markersize=1)
                if mode_i == 0:
                    ax.yaxis.set_ticks([-np.pi, 0, np.pi])
                    ax.yaxis.set_ticklabels(['-π', 0, 'π'])
                ax.xaxis.set_ticklabels([])
                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))
            # Plot cell weights
            elif subrow == nsubrows - 1:
                im = ax.imshow(resp_comps_toplot[mode_i, :],
                               cmap='RdBu_r',
                               vmin=resp_vmin,
                               vmax=resp_vmax,
                               aspect='auto',
                               interpolation='nearest',
                               extent=[
                                   t_response[0], t_response[-1], 0,
                                   resp_comps_toplot.shape[1]
                               ])
                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))
                if row == axes.shape[0] - 1:
                    ax.set_xlabel('Time before spike [ms]')
                else:
                    ax.xaxis.set_ticklabels([])

                plf.integerticks(ax, 5, which='y')
                if col == 0:
                    ax.set_ylabel(cells_ylabel,
                                  rotation=0,
                                  ha='right',
                                  va='center')
                else:
                    ax.yaxis.set_ticklabels([])
                if mode_i == n_components - 1:
                    plf.colorbar(im)
            # Add ticks on the right side of the plots
            if col == ncols - 1 and subrow != nsubrows - 1:
                plf.integerticks(ax, 3, which='y')
                ax.yaxis.tick_right()

    fig.suptitle(
        f'CCA components of {st.exp_foldername}\n{shufflespikes=} {n_components=}\n{sort_by_nspikes=}\n'
        + f'{select_cells=} {regularization=} {filter_length=}')
    fig.subplots_adjust(wspace=0.1, hspace=0.3)
    if savefig:
        fig.savefig(savedir / f'{figsavename}_cellsandcomponents.pdf')
    plt.close(fig)

    # ---- Figure 3: canonical correlations ----
    fig_corrs = plt.figure()
    plt.plot(cancorrs, 'ko')
    plt.xlabel('Component index')
    plt.ylabel('Correlation')
    plt.title(f'Cannonical correlations {shufflespikes=}')
    if savefig:
        fig_corrs.savefig(savedir / f'{figsavename}_correlation_coeffs.pdf')
    plt.close(fig_corrs)

    # ---- Figure 4: nonlinearities between the two generator signals ----
    fig_nlt, axes_nlt = plt.subplots(nrows, ncols, figsize=(10, 10),
                                     squeeze=False)

    stim_comps_flatter = stim_comps[:n_components].reshape(
        (n_components, 2 * filter_length[0])).T
    resp_comps_flatter = resp_comps[:n_components].reshape(
        (n_components, resp_comps.shape[1] * filter_length[1])).T

    # Reshape to perform the convolution as a matrix multiplication
    generator_stim = stimulus @ stim_comps_flatter
    generator_resp = spikes @ resp_comps_flatter

    # The generators only hold n_components columns, so only that many axes
    # can be filled; spare grid cells are switched off instead of indexing
    # past the end.
    axes_nlt_flat = axes_nlt.flatten()
    used_axes = axes_nlt_flat[:n_components]
    for ax in axes_nlt_flat[n_components:]:
        ax.axis('off')

    for i, ax in enumerate(used_axes):
        nonlinearity, bins = nlt.calc_nonlin(generator_resp[:, i],
                                             generator_stim[:, i])
        ax.plot(bins, nonlinearity, 'k')
        if i == 0:
            # Allocate once the output shapes are known.
            all_nonlinearities = np.empty((n_components, *nonlinearity.shape))
            all_bins = np.empty((n_components, *bins.shape))
        all_nonlinearities[i, ...] = nonlinearity
        all_bins[i, ...] = bins

    # Use common axis limits so the panels are directly comparable.
    nlt_xlims = []
    nlt_ylims = []
    for ax in used_axes:
        nlt_xlims.extend(ax.get_xlim())
        nlt_ylims.extend(ax.get_ylim())
    nlt_minx, nlt_maxx = min(nlt_xlims), max(nlt_xlims)
    nlt_miny, nlt_maxy = min(nlt_ylims), max(nlt_ylims)

    for i, axes_row in enumerate(axes_nlt):
        for j, ax in enumerate(axes_row):
            if i == nrows - 1:
                ax.set_xlabel('Generator (motion)')
            if j == 0:
                ax.set_ylabel('Generator (cells)')
            else:
                ax.yaxis.set_ticklabels([])
            ax.set_xlim([nlt_minx, nlt_maxx])
            ax.set_ylim([nlt_miny, nlt_maxy])

    fig_nlt.suptitle(f'Nonlinearities\n{figsavename}')
    if savefig:
        fig_nlt.savefig(savedir / f'{figsavename}_nonlinearity.png')
    plt.close(fig_nlt)

    # ---- Save the results ----
    datadict = {
        'n_components': n_components,
        'resp_comps': resp_comps,
        'stim_comps': stim_comps,
        'motionfilt_r': motionfilt_r,
        'motionfilt_theta': motionfilt_theta,
        'resp_comps_sorted_nsp': resp_comps_sorted_nsp,
        'select_cells': select_cells,
        'regularization': regularization,
        'filter_length': filter_length,
        'all_nonlinearities': all_nonlinearities,
        'all_bins': all_bins,
        'cca_solver': cca_solver,
    }
    np.savez(savedir / figsavename, **datadict)
Exemple #3
0
def OMBanalyzer(exp_name, stimnr, plotall=False, nr_bins=20):
    """
    Analyze responses to object moving background stimulus. STA and STC
    are calculated.

    For every cluster the spike-triggered average (STA) and the non-centered
    spike-triggered covariance (STC) of the x and y background steps are
    computed, the STCs are eigendecomposed and a nonlinearity is calculated
    for the projection on the last eigenvector. A population polar plot and
    one figure per cluster are saved as SVG, and the results are stored in
    ``<stimnr>_data.npz`` under ``data_analysis/<stimname>`` in the
    experiment directory. Nothing is returned.

    Note that there are additional functions that make use of the
    OMB class. This function was written before the OMB class existed

    Parameters
    ---
    exp_name:
        Experiment name, resolved with ``iof.exp_dir_fixer``.
    stimnr:
        Number of the stimulus to analyze.
    plotall: bool
        If True, each figure is also shown with ``plt.show()`` before it is
        closed.
    nr_bins: int
        Number of bins passed to ``nlt.calc_nonlin``.
    """
    # TODO
    # Add iteration over multiple stimuli

    exp_dir = iof.exp_dir_fixer(exp_name)
    exp_name = os.path.split(exp_dir)[-1]
    stimname = iof.getstimname(exp_dir, stimnr)

    parameters = asc.read_parameters(exp_name, stimnr)
    assert parameters['stimulus_type'] == 'objectsmovingbackground'
    stimframes = parameters.get('stimFrames', 108000)
    nblinks = parameters.get('Nblinks', 2)
    seed = parameters.get('seed', -10000)
    stepsize = parameters.get('stepsize', 2)

    ntotal = int(stimframes / nblinks)

    clusters, metadata = asc.read_spikesheet(exp_name)
    ncells = clusters.shape[0]

    refresh_rate = metadata['refresh_rate']
    filter_length, frametimings = asc.ft_nblinks(exp_name, stimnr, nblinks,
                                                 refresh_rate)
    frame_duration = np.ediff1d(frametimings).mean()
    frametimings = frametimings[:-1]

    if ntotal != frametimings.shape[0]:
        print(f'For {exp_name}\nstimulus {stimname} :\n'
              f'Number of frames specified in the parameters file ({ntotal}'
              f' frames) and frametimings ({frametimings.shape[0]}) do not'
              ' agree!'
              ' The stimulus was possibly interrupted during recording.'
              ' ntotal is changed to match actual frametimings.')
        ntotal = frametimings.shape[0]

    # Generate the numbers to be used for reconstructing the motion
    # ObjectsMovingBackground.cpp line 174, steps are generated in an
    # alternating fashion. We can generate all of the numbers at once
    # (total lengths is defined by stimFrames) and then assign
    # to x and y directions. Although there is more
    # stuff around line 538
    randnrs, _ = randpy.gasdev(seed, ntotal * 2)
    randnrs = np.array(randnrs) * stepsize

    xsteps = randnrs[::2]
    ysteps = randnrs[1::2]

    clusterids = plf.clusters_to_ids(clusters)

    # Bin the spikes of every cluster to the stimulus frames
    all_spikes = np.empty((ncells, ntotal))
    for i, (cluster, channel, _) in enumerate(clusters):
        spiketimes = asc.read_raster(exp_name, stimnr, cluster, channel)
        all_spikes[i, :] = asc.binspikes(spiketimes, frametimings)

    # Collect STA for x and y movement in one array
    stas = np.zeros((ncells, 2, filter_length))
    stc_x = np.zeros((ncells, filter_length, filter_length))
    stc_y = np.zeros((ncells, filter_length, filter_length))
    t = np.arange(filter_length) * 1000 / refresh_rate * nblinks
    for k in range(filter_length, ntotal - filter_length + 1):
        spiking_cells = np.flatnonzero(all_spikes[:, k])
        if spiking_cells.size == 0:
            continue
        # Stimulus history preceding frame k, most recent step first
        x_mini = xsteps[k - filter_length + 1:k + 1][::-1]
        y_mini = ysteps[k - filter_length + 1:k + 1][::-1]
        # The outer products depend only on the frame, so compute them once
        # per frame instead of once per spiking cell.
        covar_x = calc_covar(x_mini)
        covar_y = calc_covar(y_mini)
        for i in spiking_cells:
            nsp = all_spikes[i, k]
            stas[i, 0, :] += nsp * x_mini
            stas[i, 1, :] += nsp * y_mini
            # Calculate non-centered STC (Cantrell et al., 2010)
            stc_x[i, :, :] += nsp * covar_x
            stc_y[i, :, :] += nsp * covar_y

    eigvals_x = np.zeros((ncells, filter_length))
    eigvals_y = np.zeros((ncells, filter_length))
    eigvecs_x = np.zeros((ncells, filter_length, filter_length))
    eigvecs_y = np.zeros((ncells, filter_length, filter_length))

    bins_x = np.zeros((ncells, nr_bins))
    bins_y = np.zeros((ncells, nr_bins))
    spikecount_x = np.zeros(bins_x.shape)
    spikecount_y = np.zeros(bins_x.shape)
    generators_x = np.zeros(all_spikes.shape)
    generators_y = np.zeros(all_spikes.shape)

    # Normalize STAs and STCs with respect to spike numbers. The totals are
    # computed once for all cells rather than inside the per-cell loop.
    # Cells without any spikes end up as NaN here (division by zero).
    spike_totals = all_spikes.sum(axis=1)
    for i in range(ncells):
        totalspikes = spike_totals[i]
        stas[i, :, :] = stas[i, :, :] / totalspikes
        stc_x[i, :, :] = stc_x[i, :, :] / totalspikes
        stc_y[i, :, :] = stc_y[i, :, :] / totalspikes
        try:
            eigvals_x[i, :], eigvecs_x[i, :, :] = np.linalg.eigh(
                stc_x[i, :, :])
            eigvals_y[i, :], eigvecs_y[i, :, :] = np.linalg.eigh(
                stc_y[i, :, :])
        except np.linalg.LinAlgError:
            # Leave the remaining outputs of this cell at zero.
            continue
        # Calculate the generator signals and nonlinearities. Keeping the
        # first ntotal samples of the full convolution drops the trailing
        # filter_length - 1 samples and, unlike [:-filter_length + 1], also
        # works when filter_length is 1.
        generators_x[i, :] = np.convolve(eigvecs_x[i, :, -1],
                                         xsteps,
                                         mode='full')[:ntotal]
        generators_y[i, :] = np.convolve(eigvecs_y[i, :, -1],
                                         ysteps,
                                         mode='full')[:ntotal]
        spikecount_x[i, :], bins_x[i, :] = nlt.calc_nonlin(
            all_spikes[i, :], generators_x[i, :], nr_bins)
        spikecount_y[i, :], bins_y[i, :] = nlt.calc_nonlin(
            all_spikes[i, :], generators_y[i, :], nr_bins)

    savepath = os.path.join(exp_dir, 'data_analysis', stimname)
    os.makedirs(savepath, exist_ok=True)

    # Calculated based on last eigenvector
    magx = eigvecs_x[:, :, -1].sum(axis=1)
    magy = eigvecs_y[:, :, -1].sum(axis=1)
    r_ = np.sqrt(magx**2 + magy**2)
    theta_ = np.arctan2(magy, magx)
    # To draw the vectors starting from origin, insert zeros every other element
    r = np.zeros(r_.shape[0] * 2)
    theta = np.zeros(theta_.shape[0] * 2)
    r[1::2] = r_
    theta[1::2] = theta_

    # Population plot with one vector per cell
    plt.polar(theta, r)
    plt.gca().set_xticks(np.pi / 180 * np.array([0, 90, 180, 270]))
    plt.title(f'Population plot for motion STAs\n{exp_name}')
    plt.savefig(os.path.join(savepath, 'population.svg'))
    if plotall:
        plt.show()
    plt.close()

    # One summary figure per cell
    for i in range(stas.shape[0]):
        stax = stas[i, 0, :]
        stay = stas[i, 1, :]
        ax1 = plt.subplot(211)
        ax1.plot(t, stax, label=r'STA$_{\rm X}$')
        ax1.plot(t, stay, label=r'STA$_{\rm Y}$')
        ax1.plot(t, eigvecs_x[i, :, -1], label='Eigenvector_X 0')
        ax1.plot(t, eigvecs_y[i, :, -1], label='Eigenvector_Y 0')
        plt.legend(fontsize='x-small')

        ax2 = plt.subplot(4, 4, 9)
        ax3 = plt.subplot(4, 4, 13)
        ax2.set_yticks([])
        ax2.set_xticklabels([])
        ax3.set_yticks([])
        ax2.set_title('Eigenvalues', size='small')
        ax2.plot(eigvals_x[i, :],
                 'o',
                 markerfacecolor='C0',
                 markersize=4,
                 markeredgewidth=0)
        ax3.plot(eigvals_y[i, :],
                 'o',
                 markerfacecolor='C1',
                 markersize=4,
                 markeredgewidth=0)
        ax4 = plt.subplot(2, 3, 5)
        # Dividing the values per frame by the frame duration gives Hz
        ax4.plot(bins_x[i, :], spikecount_x[i, :] / frame_duration)
        ax4.plot(bins_y[i, :], spikecount_y[i, :] / frame_duration)
        ax4.set_ylabel('Firing rate [Hz]')
        ax4.set_title('Nonlinearities', size='small')
        plf.spineless([ax1, ax2, ax3, ax4], 'tr')
        ax5 = plt.subplot(2, 3, 6, projection='polar')
        ax5.plot(theta, r, color='k', alpha=.3)
        # Highlight the vector belonging to the current cell
        ax5.plot(theta[2 * i:2 * i + 2], r[2 * i:2 * i + 2], lw=3)
        ax5.set_xticklabels(['0', '', '', '', '180', '', '270', ''])
        ax5.set_title('Vector sum of X and Y STCs', size='small')
        plt.suptitle(f'{exp_name}\n{stimname}\n{clusterids[i]}')
        plt.subplots_adjust(hspace=.4)
        plt.savefig(os.path.join(savepath, clusterids[i] + '.svg'),
                    bbox_inches='tight')
        if plotall:
            plt.show()
        plt.close()

    datadict = {
        'nblinks': nblinks,
        'all_spikes': all_spikes,
        'clusters': clusters,
        'frame_duration': frame_duration,
        'eigvals_x': eigvals_x,
        'eigvals_y': eigvals_y,
        'eigvecs_x': eigvecs_x,
        'eigvecs_y': eigvecs_y,
        'filter_length': filter_length,
        'magx': magx,
        'magy': magy,
        'ntotal': ntotal,
        'r': r,
        'theta': theta,
        'stas': stas,
        'stc_x': stc_x,
        'stc_y': stc_y,
        'bins_x': bins_x,
        'bins_y': bins_y,
        'nr_bins': nr_bins,
        'spikecount_x': spikecount_x,
        'spikecount_y': spikecount_y,
        'generators_x': generators_x,
        'generators_y': generators_y,
        't': t,
    }

    npzfpath = os.path.join(savepath, str(stimnr) + '_data')
    np.savez(npzfpath, **datadict)
Exemple #4
0
                 omb_stas[i, j, :],
                 label='OMB STA',
                 color='k',
                 alpha=.6,
                 ls='dashed')
        if j == 0:
            axk.set_title('Linear filters')
            axnlt.set_title('Nonlinearity')

        axk.set_xlabel('Time before spike [ms]')
        axk.set_ylabel(f'{plotlabels[j]}')

        generator = np.convolve(kall[i, j, :], stimulus[j, :],
                                mode='full')[:-st.filter_length + 1]
        nonlinearity, bins = nlt.calc_nonlin(all_spikes[i, :],
                                             generator,
                                             nr_bins=40)
        axnlt.plot(bins, nonlinearity / st.frame_duration)
        axnlt.set_xlabel('Stimulus projection')
        axnlt.set_ylabel('Firing rate [spikes/s]')

    plt.tight_layout()
    plt.subplots_adjust(top=.80)
    fig.suptitle(
        f'{st.exp_foldername} \n {st.stimname} \n'
        f'{st.clids[i]} {glmlabel} mu: {muall[i]:4.2f} '
        f'corr: {cross_corrs[i]:4.2f} nsp: {all_spikes[i, :].sum():5.0f}')
    plt.show()
    fig.savefig(os.path.join(savepath, f'{st.clids[i]}.svg'))
    plt.close()
Exemple #5
0
# Script fragment: project the background steps onto the normalized motion
# STAs, compute a nonlinearity per cell and split the spikes by the sign of
# the projection. Relies on `motionstas`, `st`, `nrcells` and `all_spikes`
# being defined earlier in the script.

# Euclidean norm: scale every filter (last axis) to unit length
motionstas_norm = motionstas / np.sqrt((motionstas**2).sum(axis=-1))[:, :, None]

# Scale the steps by the standard deviation taken over all elements
bgsteps = st.bgsteps / np.sqrt(st.bgsteps.var())
# presumably yields one window of st.filter_length samples per frame along a
# new last axis -- TODO confirm against asc.rolling_window
rw = asc.rolling_window(bgsteps, st.filter_length)


# Sum over the shared axes b and c; the result is indexed (a, d).
# presumably a=cell, b=motion axis (x/y), c=time lag, d=frame -- TODO confirm
steps_proj = np.einsum('abc,bdc->ad', motionstas_norm, rw)

nbins = 15
bins = np.zeros((nrcells, nbins))
nlts = np.zeros((bins.shape))

# One nonlinearity per cell from its spikes and its stimulus projection
for i in range(nrcells):
    nlts[i, :], bins[i, :]  = nlt.calc_nonlin(all_spikes[i, :], steps_proj[i, :], nr_bins=nbins)

# presumably converts values per frame into a rate in Hz -- verify that
# st.refresh_rate matches the bin rate of all_spikes
nlts *= st.refresh_rate

#%%
# Boolean masks of frames with a projection above 0.5 / below -0.5.
# NOTE(review): `pd` presumably stands for "preferred direction" and `npd`
# for the opposite one; `pd` would shadow the conventional pandas alias.
pd = steps_proj > 0.5
npd = steps_proj < -0.5

pd_spikes = all_spikes.copy()
npd_spikes = all_spikes.copy()

# Set the spikes outside of the respective mask to zero
pd_spikes[~pd] = 0
npd_spikes[~npd] = 0

#%%
Exemple #6
0
# Script fragment: simulate spikes from a linear-nonlinear model, then
# recover the filter with a spike-triggered average and compare two
# nonlinearity estimators. Relies on `stim`, `linear_filter`,
# `filter_length` and the callable `nonlinearity` defined earlier.

# Filter the stimulus; dropping the last filter_length - 1 samples of the
# full convolution leaves one generator value per stimulus sample
generator = np.convolve(stim, linear_filter[::-1], 'full')[:-filter_length+1]
firing_rate = nonlinearity(generator)

# Draw Poisson distributed spike counts with the firing rate as mean
spikes = np.random.poisson(firing_rate)

#%%  Recovery
sta = np.zeros(filter_length)

# Accumulate the stimulus window ending at each sample, weighted by the
# spike count of that sample.
# NOTE(review): the window is already complete at i == filter_length - 1,
# but that index is skipped here, while the normalization below counts the
# spikes of all samples.
for i, spike in enumerate(spikes):
    if i >= filter_length:
        sta += stim[i-filter_length+1:i+1]*spike
sta /= spikes.sum()

plt.plot(linear_filter, label='Filter')
plt.plot(sta, label='Spike-triggered average')
plt.legend(fontsize='x-small')
plt.show()
#%%
import nonlinearity as nlt

# The STA is stored oldest sample first, so it is reversed before being
# used as a convolution kernel (same treatment as linear_filter above)
regenerator = np.convolve(sta[::-1], stim, 'full')[:-filter_length+1]

# Compare both estimators; note that the two calls are unpacked in
# opposite order (values, bins) vs. (bins, values)
nonlin, bins = nlt.calc_nonlin(spikes, regenerator)
binsold, nonlinold = nlt._calc_nonlin(spikes, regenerator)

plt.plot(bins, nonlin, label='new')
plt.plot(binsold, nonlinold, label='old')
plt.plot(bins, nonlinearity(bins), label='real_nonlinearity')
plt.legend()