Esempio n. 1
0
 def show(self, cm='RdBu', onlymax=True):
     """
     Plot every slice along the last axis of ``self.sta`` in one figure.

     Each slice ``self.sta[:, :, i]`` is drawn with ``imshow`` in its own
     subplot. All subplots share colour limits that are symmetric around
     zero, set by the magnitude of the value at index ``self.maxi``.

     Parameters
     ----------
     cm:
         Colormap passed to ``imshow``.
     onlymax:
         Accepted for interface compatibility; not used by this method.
     """
     a = self.sta
     # The value at the extremum may be negative. Use its magnitude so that
     # vmin <= vmax always holds and the colour scale is not inverted.
     vmax = abs(a[tuple(self.maxi)])
     vmin = -vmax
     inds = a.shape[-1]
     rows, columns = plf.numsubplots(inds)
     for i in range(inds):
         ax = plt.subplot(rows, columns, i + 1)
         ax.imshow(a[:, :, i], vmax=vmax, vmin=vmin, cmap=cm)
     plt.show()
Esempio n. 2
0
              copy=True)

    cca.fit(spikes, stimulus)

    x, y = cca.transform(spikes, stimulus)
    # x, y = x, y

    cells = cca.x_weights_.T
    cells = cells.reshape((n_components, st.nclusters, filter_length))

    motionfilt_x = cca.y_weights_[:filter_length].T
    motionfilt_y = cca.y_weights_[filter_length:].T

    motionfilt_r, motionfilt_theta = cart2pol(motionfilt_x, motionfilt_y)
    #%%
    nsubplots = plf.numsubplots(n_components)
    if singlecell:
        nsubplots = (1, 1)
    height_list = [1, 1, 1, 3]  # ratios of the plots in each component
    nsubrows = len(height_list)
    height_ratios = nsubplots[0] * height_list
    fig, axes = plt.subplots(nrows=nsubplots[0] * nsubrows,
                             ncols=nsubplots[1],
                             gridspec_kw={'height_ratios': height_ratios},
                             figsize=(9, 10))
    if singlecell:
        axes = axes[:, None]
    for row, ax_row in enumerate(axes):
        for col, ax in enumerate(ax_row):
            mode_i = int(row / nsubrows) * nsubplots[1] + col
            if singlecell: mode_i -= 1
Esempio n. 3
0
def stripestim(exp_name):
    """
    Return the pair of stripe flicker stimulus numbers for an experiment.

    Parameters
    ----------
    exp_name: str
        Experiment name; it only needs to contain one of the known
        recording dates ('20180118', '20180124' or '20180207').

    Returns
    -------
    list
        Two integers, ``[6, 12]`` or ``[7, 14]`` depending on the date.

    Raises
    ------
    ValueError
        If ``exp_name`` contains none of the known dates. Previously this
        case failed with an UnboundLocalError at the return statement.
    """
    if '20180124' in exp_name or '20180207' in exp_name:
        return [6, 12]
    if '20180118' in exp_name:
        return [7, 14]
    raise ValueError(
        f'No stripe flicker stimuli are known for experiment {exp_name!r}')


# Experiments to process; each name is resolved to its stimulus pair by
# stripestim().
exps = ['20180118', '20180124', '20180207']

# Precomputed auxiliary arrays, loaded from an absolute, machine-specific path.
data = np.load('/home/ycan/Documents/thesis/analysis_auxillary_files/'
               'thesis_csiplotting.npz')
include = data['include']
cells = data['cells']
groups = data['groups']

# Uninitialised buffer with one trailing axis of length 73 per entry of
# `cells`. NOTE(review): nothing in the visible code writes to or reads from
# it - confirm whether the filling step was lost from this excerpt.
all_fits = np.empty((*cells.shape, 73))

# NOTE(review): fits_m and fits_p are overwritten on every iteration, so only
# the fits of the last experiment in `exps` are still available after the loop.
for exp in exps:
    stim = stripestim(exp)
    fits_m = np.array(iof.load(exp, stim[0])['fits'])
    fits_p = np.array(iof.load(exp, stim[1])['fits'])

# NOTE(review): `nrcells` is not defined anywhere in the visible code; it
# presumably is the number of rows of fits_m / fits_p - confirm.
p = plf.numsubplots(nrcells)
axes = plt.subplots(*p)[1].ravel()
# One subplot per cell, overlaying the two fitted curves of that cell.
for i in range(nrcells):
    ax = axes[i]
    ax.plot(fits_m[i, :])
    ax.plot(fits_p[i, :])
    plf.spineless(ax)
    ax.set_axis_off()
Esempio n. 4
0
def cca_omb_components(exp: str,
                       stim_nr: int,
                       n_components: int = 6,
                       regularization=None,
                       filter_length=None,
                       cca_solver: str = 'macke',
                       maxframes=None,
                       shufflespikes: bool = False,
                       exclude_allzero_spike_rows: bool = True,
                       savedir: str = None,
                       savefig: bool = True,
                       sort_by_nspikes: bool = True,
                       select_cells: list = None,
                       plot_first_ncells: int = None,
                       whiten: bool = False):
    """
    Analyze OMB responses using canonical correlation analysis and plot the results.

    Four figures are produced (cell weights in default order, components
    together with cell weights, canonical correlations, nonlinearities) and
    a .npz file with the results is always written to the save directory.
    The function returns nothing.

    Parameters
    ---
    exp:
        Experiment identifier, passed to the OMB stimulus object.
    stim_nr:
        Stimulus number, passed to the OMB stimulus object.
    n_components:
        Number of components that will be requested from the CCA analysis. More numbers mean
        the algorithm will stop at a later point. That means components of analyses with fewer
        n_components are going to be identical to the first n components of the higher-number
        component analyses.
    regularization:
        The regularization parameter forwarded to the rcca solver (None is treated as 0).
        Not relevant for macke.
    filter_length:
        The length of the time window to be considered in the past for the stimulus and the responses.
        Can be different for stimulus and response, if a tuple (stimulus, response) is given.
        Defaults to the filter length of the stimulus object.
    cca_solver:
        Which CCA solver to use. Options are `rcca` and `macke`(default)
    maxframes: int
        Number of frames to load in the experiment object. Used to avoid memory and performance
        issues.
    shufflespikes: bool
        Whether to randomize the spikes, to validate the results
    exclude_allzero_spike_rows:
        Exclude all cells which have zero spikes for the duration of the stimulus.
    savedir: str
        Custom directory to save the figures and data files. If None, will be saved in the
        'CCA' folder under the stimulus directory. The directory is created if missing.
    savefig: bool
        Whether to save the figures.
    sort_by_nspikes: bool
        Whether to sort the cell weights array by the number of spikes during the stimulus.
    select_cells: list
       A list of indexes for the subset of cells to perform the analysis for.
    plot_first_ncells: int
        Number of cells to plot in the cell plots.
    whiten: bool
        Forwarded to the rcca solver; not used by macke.

    Raises
    ---
    ValueError
        If cca_solver is not one of the supported options.
    """
    if regularization is None:
        regularization = 0

    # Fail early instead of hitting a NameError after the data is loaded.
    if cca_solver not in ('rcca', 'macke'):
        raise ValueError(f"Unknown cca_solver {cca_solver!r}; "
                         "options are 'rcca' and 'macke'.")

    st = OMB(exp, stim_nr, maxframes=maxframes)

    if filter_length is None:
        filter_length = st.filter_length

    # Also accept numpy integers, which are not instances of int.
    if isinstance(filter_length, (int, np.integer)):
        filter_length = (int(filter_length), int(filter_length))

    if savedir is None:
        savedir = st.stim_dir / 'CCA'
    else:
        savedir = Path(savedir)
    # The data file is always written, so the directory must exist even when
    # it was supplied by the caller.
    savedir.mkdir(exist_ok=True, parents=True)

    spikes = st.allspikes()
    # Count the spikes per cell BEFORE removing the mean; afterwards every row
    # sums to (approximately) zero and the counts would be meaningless.
    nspikes_percell = spikes.sum(axis=1)
    nonzerospikerows = ~np.isclose(nspikes_percell, 0)
    # Set the mean to zero for spikes
    spikes -= spikes.mean(axis=1)[:, None]

    if select_cells is not None:
        select_cells = np.asarray(select_cells)
        spikes = spikes[select_cells]
        # Keep the per-cell bookkeeping aligned with the selected rows.
        nspikes_percell = nspikes_percell[select_cells]
        nonzerospikerows = nonzerospikerows[select_cells]
        st.nclusters = len(select_cells)
        # Convert to list for better string representation
        # np.array is printed as "array([....])"
        # with newline characters which is problematic in filenames
        select_cells = list(select_cells)

    # Exclude rows that have no spikes throughout
    if exclude_allzero_spike_rows:
        spikes = spikes[nonzerospikerows, :]
        nspikes_percell = nspikes_percell[nonzerospikerows]

    if shufflespikes:
        spikes = spikeshuffler.shufflebyrow(spikes)

    figsavename = f'{n_components=}_{shufflespikes=}_{select_cells=}_{regularization=}_{filter_length=}_{whiten=}'
    # If the file length gets too long due to the list of selected cells, summarize it.
    if len(figsavename) > 200 and select_cells is not None:
        figsavename = f'{n_components=}_{shufflespikes=}_select_cells={len(select_cells)}cells-index{select_cells[0]}to{select_cells[-1]}_{regularization=}_{filter_length=}_{whiten=}'

    stimulus = mft.packdims(st.bgsteps, filter_length[0])
    spikes = mft.packdims(spikes, filter_length[1])

    if cca_solver == 'rcca':
        resp_comps, stim_comps, cancorrs = cca_rcca(spikes, stimulus,
                                                    filter_length,
                                                    n_components,
                                                    regularization, whiten)
    else:  # 'macke', validated above
        resp_comps, stim_comps, cancorrs = cca_macke(spikes, stimulus,
                                                     filter_length,
                                                     n_components)

    nsp_argsorted = np.argsort(nspikes_percell)
    resp_comps_sorted_nsp = resp_comps[:, nsp_argsorted, :]

    if sort_by_nspikes:
        resp_comps_toplot = resp_comps_sorted_nsp
    else:
        resp_comps_toplot = resp_comps

    if plot_first_ncells is not None:
        resp_comps_toplot = resp_comps_toplot[:, :plot_first_ncells, ...]

    motionfilt_r, motionfilt_theta = mft.cart2pol(stim_comps[:, 0, :],
                                                  stim_comps[:, 1, :])

    # Shared, symmetric colour limits for all cell weight images
    resp_vmin = asc.absmin(resp_comps)
    resp_vmax = asc.absmax(resp_comps)

    #%%
    nrows, ncols = plf.numsubplots(n_components)
    # squeeze=False keeps a 2D array of axes even for single row/column grids
    fig_cells, axes_cells = plt.subplots(nrows, ncols, figsize=(10, 10),
                                         squeeze=False)

    for i in range(n_components):
        ax = axes_cells.flat[i]
        im = ax.imshow(resp_comps[i, :],
                       cmap='RdBu_r',
                       vmin=resp_vmin,
                       vmax=resp_vmax,
                       aspect='auto',
                       interpolation='nearest')
        ax.set_title(f'{i}')
    fig_cells.suptitle(f'Cells default order {shufflespikes=}')
    if savefig:
        fig_cells.savefig(savedir / f'{figsavename}_cells_default_order.pdf')
    plt.close(fig_cells)

    height_list = [1, 1, 1, 3]  # ratios of the plots in each component

    # Create time vectors (in ms, ending at zero) for the plots. An integer
    # arange guarantees exactly filter_length samples; a float step can yield
    # one sample too many because of rounding.
    t_stim = -np.arange(filter_length[0])[::-1] * st.frame_duration * 1000
    t_response = -np.arange(filter_length[1])[::-1] * st.frame_duration * 1000
    xtick_loc_params = dict(nbins=4, steps=[2, 5, 10], integer=True)

    # Every component occupies `nsubrows` consecutive rows in its column
    nsubrows = len(height_list)
    height_ratios = nrows * height_list
    fig, axes = plt.subplots(nrows=nrows * nsubrows,
                             ncols=ncols,
                             gridspec_kw={'height_ratios': height_ratios},
                             figsize=(11, 10),
                             squeeze=False)

    for row, ax_row in enumerate(axes):
        subrow = row % nsubrows
        for col, ax in enumerate(ax_row):
            mode_i = (row // nsubrows) * ncols + col
            # The grid can have more slots than components; leave those empty
            if mode_i >= n_components:
                ax.set_axis_off()
                continue
            ax.set_yticks([])
            # Plot motion filters
            if subrow == 0:
                ax.plot(t_stim,
                        stim_comps[mode_i, 0, :],
                        marker='o',
                        markersize=1)
                ax.plot(t_stim,
                        stim_comps[mode_i, 1, :],
                        marker='o',
                        markersize=1)
                if col == 0:
                    ax.set_ylabel('Motion',
                                  rotation=0,
                                  ha='right',
                                  va='center')
                ax.set_ylim(stim_comps.min(), stim_comps.max())

                # Draw a horizontal line for zero and prevent rescaling of x-axis
                xlims = ax.get_xlim()
                ax.hlines(0,
                          *xlims,
                          colors='k',
                          linestyles='dashed',
                          alpha=0.3)
                ax.set_xlim(*xlims)

                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))

                # Tick labels are only kept on the first component, and only
                # when stimulus and response windows differ in length
                if not mode_i == 0 or filter_length[0] == filter_length[1]:
                    ax.xaxis.set_ticklabels([])
                else:
                    ax.tick_params(axis='x', labelsize=8)

            # Plot magnitude of motion
            elif subrow == 1:
                ax.plot(t_stim,
                        motionfilt_r[mode_i, :],
                        color='k',
                        marker='o',
                        markersize=1)
                if col == 0:
                    ax.set_ylabel('Magnitude',
                                  rotation=0,
                                  ha='right',
                                  va='center')
                ax.set_ylim(motionfilt_r.min(), motionfilt_r.max())
                ax.xaxis.set_ticklabels([])
                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))
            # Plot direction of motion
            elif subrow == 2:
                ax.plot(t_stim,
                        motionfilt_theta[mode_i, :],
                        color='r',
                        marker='o',
                        markersize=1)
                if mode_i == 0:
                    ax.yaxis.set_ticks([-np.pi, 0, np.pi])
                    ax.yaxis.set_ticklabels(['-π', 0, 'π'])
                ax.xaxis.set_ticklabels([])
                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))
            # Plot cell weights
            elif subrow == nsubrows - 1:
                im = ax.imshow(resp_comps_toplot[mode_i, :],
                               cmap='RdBu_r',
                               vmin=resp_vmin,
                               vmax=resp_vmax,
                               aspect='auto',
                               interpolation='nearest',
                               extent=[
                                   t_response[0], t_response[-1], 0,
                                   resp_comps_toplot.shape[1]
                               ])
                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))
                if row == axes.shape[0] - 1:
                    ax.set_xlabel('Time before spike [ms]')
                else:
                    ax.xaxis.set_ticklabels([])

                plf.integerticks(ax, 5, which='y')
                if col == 0:
                    ax.set_ylabel(
                        f'Cells\n{"(sorted nsp)"*sort_by_nspikes}\n{("(first " + str(plot_first_ncells)+ " cells)")*(type(plot_first_ncells) is int) }',
                        rotation=0,
                        ha='right',
                        va='center')
                else:
                    ax.yaxis.set_ticklabels([])
                if mode_i == n_components - 1:
                    plf.colorbar(im)
            # Add ticks on the right side of the plots
            if col == ncols - 1 and subrow != nsubrows - 1:
                plf.integerticks(ax, 3, which='y')
                ax.yaxis.tick_right()

    fig.suptitle(
        f'CCA components of {st.exp_foldername}\n{shufflespikes=} {n_components=}\n{sort_by_nspikes=}\n'
        + f'{select_cells=} {regularization=} {filter_length=}')
    fig.subplots_adjust(wspace=0.1, hspace=0.3)
    if savefig:
        fig.savefig(savedir / f'{figsavename}_cellsandcomponents.pdf')
    plt.close(fig)

    #%%
    fig_corrs = plt.figure()
    plt.plot(cancorrs, 'ko')
    plt.xlabel('Component index')
    plt.ylabel('Correlation')
    plt.title(f'Cannonical correlations {shufflespikes=}')
    if savefig:
        fig_corrs.savefig(savedir / f'{figsavename}_correlation_coeffs.pdf')
    plt.close(fig_corrs)

    fig_nlt, axes_nlt = plt.subplots(nrows, ncols, figsize=(10, 10),
                                     squeeze=False)

    stim_comps_flatter = stim_comps[:n_components].reshape(
        (n_components, 2 * filter_length[0])).T
    resp_comps_flatter = resp_comps[:n_components].reshape(
        (n_components, resp_comps.shape[1] * filter_length[1])).T

    # Reshape to perform the convolution as a matrix multiplication
    generator_stim = stimulus @ stim_comps_flatter
    generator_resp = spikes @ resp_comps_flatter

    # Only the first n_components axes carry data; hide any spare grid slots
    axes_nlt_flat = axes_nlt.ravel()
    used_axes_nlt = axes_nlt_flat[:n_components]
    for ax in axes_nlt_flat[n_components:]:
        ax.set_axis_off()

    for i, ax in enumerate(used_axes_nlt):
        nonlinearity, bins = nlt.calc_nonlin(generator_resp[:, i],
                                             generator_stim[:, i])
        ax.plot(bins, nonlinearity, 'k')
        if i == 0:
            # Output shapes are only known after the first call
            all_nonlinearities = np.empty((n_components, *nonlinearity.shape))
            all_bins = np.empty((n_components, *bins.shape))
        all_nonlinearities[i, ...] = nonlinearity
        all_bins[i, ...] = bins

    # Use common axis limits so that the nonlinearities are comparable
    nlt_xlims = [lim for ax in used_axes_nlt for lim in ax.get_xlim()]
    nlt_ylims = [lim for ax in used_axes_nlt for lim in ax.get_ylim()]
    nlt_maxx, nlt_minx = max(nlt_xlims), min(nlt_xlims)
    nlt_maxy, nlt_miny = max(nlt_ylims), min(nlt_ylims)

    for i, axes_row in enumerate(axes_nlt):
        for j, ax in enumerate(axes_row):
            if i == nrows - 1:
                ax.set_xlabel('Generator (motion)')
            if j == 0:
                ax.set_ylabel('Generator (cells)')
            else:
                ax.yaxis.set_ticklabels([])
            ax.set_xlim([nlt_minx, nlt_maxx])
            ax.set_ylim([nlt_miny, nlt_maxy])

    fig_nlt.suptitle(f'Nonlinearities\n{figsavename}')
    if savefig:
        fig_nlt.savefig(savedir / f'{figsavename}_nonlinearity.png')
    plt.close(fig_nlt)

    datadict = {
        'n_components': n_components,
        'resp_comps': resp_comps,
        'stim_comps': stim_comps,
        'motionfilt_r': motionfilt_r,
        'motionfilt_theta': motionfilt_theta,
        'resp_comps_sorted_nsp': resp_comps_sorted_nsp,
        'select_cells': select_cells,
        'regularization': regularization,
        'filter_length': filter_length,
        'all_nonlinearities': all_nonlinearities,
        'all_bins': all_bins,
        'cca_solver': cca_solver,
    }
    np.savez(savedir / figsavename, **datadict)
Esempio n. 5
0
# Fit the CCA model and project both data sets onto the components.
# NOTE(review): `cca`, `spikes`, `stimulus`, `filter_length`, `n_components`,
# `plf` and `asc` are defined before the visible part of this script.
cca.fit(spikes.T, stimulus)

x, y = cca.transform(spikes.T, stimulus)
# Only x is transposed here; y is left unchanged.
x, y = x.T, y


#%%
import matplotlib.pyplot as plt

# One row of weights per component after the transpose
cells = cca.x_weights_.T

# The first `filter_length` rows of the stimulus weights are treated as the
# x part of the motion filter, the remaining rows as the y part.
motionfilt_x = cca.y_weights_[:filter_length].T
motionfilt_y = cca.y_weights_[filter_length:].T

fig, axes = plt.subplots(*plf.numsubplots(n_components))

# NOTE(review): this iterates over every axis of the grid; if the grid has
# more axes than there are components, indexing with i raises an IndexError.
for i, ax in enumerate(axes.flat):
    ax.plot(motionfilt_x[i, :])
    ax.plot(motionfilt_y[i, :])
    # Same y limits everywhere so that the components can be compared
    ax.set_ylim(cca.y_weights_.min(), cca.y_weights_.max())
fig.suptitle('Stimulus components for each CCA filter')
plt.show()

# Cell weights of all components as a single image
im = plt.imshow(cells, cmap='RdBu_r', vmin=asc.absmin(cells), vmax=asc.absmax(cells))
plt.ylabel('Components')
plt.xlabel('Cells')
plf.colorbar(im)
plt.show()

Esempio n. 6
0
@author: ycan
"""

import plotfuncs as plf
import matplotlib.pyplot as plt
import miscfuncs as msc
import iofuncs as iof

# Load one example STA and cut a stripe around its extremum with
# msc.cutstripe; the STA is then drawn once per colormap for comparison.
data = iof.load('20180124', 12)
index = 5
sta = data['stas'][index]
max_i = data['max_inds'][index]
sta, max_i = msc.cutstripe(sta, max_i, 30)

# Comma separated list of colormap names.
a = 'Accent, Accent_r, Blues, Blues_r, BrBG, BrBG_r, BuGn, BuGn_r, BuPu, BuPu_r, CMRmap, CMRmap_r, Dark2, Dark2_r, GnBu, GnBu_r, Greens, Greens_r, Greys, Greys_r, OrRd, OrRd_r, Oranges, Oranges_r, PRGn, PRGn_r, Paired, Paired_r, Pastel1, Pastel1_r, Pastel2, Pastel2_r, PiYG, PiYG_r, PuBu, PuBuGn, PuBuGn_r, PuBu_r, PuOr, PuOr_r, PuRd, PuRd_r, Purples, Purples_r, RdBu, RdBu_r, RdGy, RdGy_r, RdPu, RdPu_r, RdYlBu, RdYlBu_r, RdYlGn, RdYlGn_r, Reds, Reds_r, Set1, Set1_r, Set2, Set2_r, Set3, Set3_r, Spectral, Spectral_r, Vega10, Vega10_r, Vega20, Vega20_r, Vega20b, Vega20b_r, Vega20c, Vega20c_r, Wistia, Wistia_r, YlGn, YlGnBu, YlGnBu_r, YlGn_r, YlOrBr, YlOrBr_r, YlOrRd, YlOrRd_r, afmhot, afmhot_r, autumn, autumn_r, binary, binary_r, bone, bone_r, brg, brg_r, bwr, bwr_r, cool, cool_r, coolwarm, coolwarm_r, copper, copper_r, cubehelix, cubehelix_r, flag, flag_r, gist_earth, gist_earth_r, gist_gray, gist_gray_r, gist_heat, gist_heat_r, gist_ncar, gist_ncar_r, gist_rainbow, gist_rainbow_r, gist_stern, gist_stern_r, gist_yarg, gist_yarg_r, gnuplot, gnuplot2, gnuplot2_r, gnuplot_r, gray, gray_r, hot, hot_r, hsv, hsv_r, inferno, inferno_r, jet, jet_r, magma, magma_r, nipy_spectral, nipy_spectral_r, ocean, ocean_r, pink, pink_r, plasma, plasma_r, prism, prism_r, rainbow, rainbow_r, seismic, seismic_r, spectral, spectral_r, spring, spring_r, summer, summer_r, tab10, tab10_r, tab20, tab20_r, tab20b, tab20b_r, tab20c, tab20c_r, terrain, terrain_r, viridis, viridis_r, winter, winter_r'
b = a.split(',')
# All names from the list above that do not end in '_r'
c = [i.strip(' ') for i in b if not i.endswith('_r')]

# NOTE(review): this assignment replaces the list computed above, so only
# these six diverging colormaps are actually plotted.
c = ['bwr_r', 'RdBu', 'seismic_r', 'bwr', 'RdBu_r', 'seismic']
dims = plf.numsubplots(len(c))
plt.figure(figsize=(20, 20))
# One subplot per colormap, titled with the colormap name
for i, cm in enumerate(c):
    ax = plt.subplot(dims[0], dims[1], i + 1)
    im = plf.stashow(sta, ax, cmap=cm, ticks=[])
    plt.axis('off')
    im.axes.get_xaxis().set_visible(False)
    im.axes.get_yaxis().set_visible(False)
    ax.set_title(cm, size='x-small')
# The figure is written to the current working directory
plt.savefig('cmaps.svg', bbox_inches='tight')
plt.close()
Esempio n. 7
0
def plot_checker_stas(exp_name, stim_nr, filename=None):
    """
    Plot and save all STAs from checkerflicker analysis. The plots
    will be saved in a new folder called STAs under the data analysis
    path of the stimulus.

    <exp_dir>/data_analysis/<stim_nr>_*/<stim_nr>_data.h5 file is
    used by default. If a different file is to be used, filename
    should be supplied.

    Parameters
    ----------
    exp_name:
        Experiment name, resolved to a directory with iof.exp_dir_fixer.
    stim_nr:
        Stimulus number; converted to str before loading.
    filename:
        Optional name of an alternative data file. Its name without a
        trailing '.npz' is appended to the save folder name ('STAs_<label>')
        and to the stimulus name in the figure title.

    One PNG per cluster is written; nothing is returned.
    """

    from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar

    exp_dir = iof.exp_dir_fixer(exp_name)
    stim_nr = str(stim_nr)

    _, metadata = asc.read_spikesheet(exp_dir)
    px_size = metadata['pixel_size(um)']

    if filename:
        filename = str(filename)
        # Remove only the extension. str.strip('.npz') would strip any of the
        # characters '.', 'n', 'p', 'z' from BOTH ends of the name.
        if filename.endswith('.npz'):
            label = filename[:-len('.npz')]
        else:
            label = filename
        savefolder = 'STAs_' + label
    else:
        savefolder = 'STAs'
        label = ''

    data = iof.load(exp_name, stim_nr, fname=filename)

    clusters = data['clusters']
    stas = data['stas']
    filter_length = data['filter_length']
    stx_h = data['stx_h']
    exp_name = data['exp_name']
    stimname = data['stimname']

    # These do not depend on the cluster, so compute them once
    nrows, ncols = plf.numsubplots(filter_length)
    colormap = iof.config('colormap')
    savedir = os.path.join(exp_dir, 'data_analysis', stimname, savefolder)
    os.makedirs(savedir, exist_ok=True)

    for j in range(clusters.shape[0]):
        a = stas[j]
        # Symmetric colour limits around zero
        sta_max = np.abs(a).max()
        sta_min = -sta_max
        fig = plt.figure(dpi=250)
        for i in range(filter_length):
            ax = plt.subplot(nrows, ncols, i + 1)
            im = ax.imshow(a[:, :, i],
                           vmin=sta_min,
                           vmax=sta_max,
                           cmap=colormap)
            ax.set_aspect('equal')
            ax.set_axis_off()
            if i == 0:
                # Scale bar spanning 10 stimulus pixels, labelled in µm
                scalebar = AnchoredSizeBar(ax.transData,
                                           10,
                                           f'{10 * stx_h * px_size} µm',
                                           'lower left',
                                           pad=0,
                                           color='k',
                                           frameon=False,
                                           size_vertical=1)
                ax.add_artist(scalebar)
            if i == filter_length - 1:
                plf.colorbar(im, ticks=[sta_min, 0, sta_max], format='%.2f')

        cluster_id = f'{clusters[j][0]:0>3}{clusters[j][1]:0>2}'
        fig.suptitle(f'{exp_name}\n{stimname + label}\n'
                     f'{cluster_id} Rating: {clusters[j][2]}')

        fig.savefig(os.path.join(savedir, cluster_id) + '.png',
                    bbox_inches='tight')
        plt.close(fig)
    print(f'Plotted checkerflicker STA for {stimname}')
Esempio n. 8
0
def ombtexturesta(exp, ombstimnr, maxframes=10000,
                  contrast_window=100, plot=False):
    """
    Calculates the spike-triggered average for the full texture for the OMB
    stimulus. Based on the maximum intensity pixel of the STAs, calculates
    the center of the receptive field and the contrast signal for this
    pixel throughout the stimulus; to be used as input for models.

    The results are saved as an .npz file and the time courses of the
    center pixels as 'texturestas.svg' in the data analysis folder of the
    stimulus.

    Parameters:
    --------
        exp:
            The experiment name
        ombstimnr:
            Number of the OMB stimulus in the experiment
        maxframes:
            Maximum number of frames that will be used, typically the
            array containing the contrast is very large and
            it is easy to fill the RAM. Refer to OMB.generatecontrast()
            documentation. If falsy, the total number of frames of the
            stimulus is used in the file name instead.
        contrast_window:
            Number of pixels to be used for the size of the texture.
            Measured in each direction starting from the center so
            a value of 100 will yield texture with size (201, 201, N)
            where N is the total number of frames.
        plot:
            If True draws an interactive plot for browsing all STAs,
            also marking the center pixels. Requires an interactive backend

    Returns:
    --------
        The figure of the STA browser if plot is True, otherwise None.

    """
    st = OMB(exp, ombstimnr, maxframes=maxframes)
    st.clusterstats()

    contrast = st.generatecontrast(st.texpars.noiselim/2,
                                   window=contrast_window,
                                   pad_length=st.filter_length-1)

    contrast_avg = contrast.mean(axis=-1)

    RW = asc.rolling_window(contrast, st.filter_length, preserve_dim=False)

    all_spikes = np.zeros((st.nclusters, st.ntotal))
    for i in range(st.nclusters):
        all_spikes[i, :] = st.binnedspiketimes(i)

    # Spike-weighted sum over the third axis of RW for every cell, then
    # normalised by the total number of spikes of that cell
    texturestas = np.einsum('abcd,ec->eabd', RW, all_spikes)
    # NOTE(review): a cell without any spikes makes this a division by zero
    texturestas /= all_spikes.sum(axis=-1)[:, np.newaxis,
                                           np.newaxis, np.newaxis]

    # Correct for the non-informative parts of the stimulus
    texturestas = texturestas - contrast_avg[None, ..., None]
    #%%
    if plot:
        fig_stas, _ = plf.multistabrowser(texturestas, cmap='Greys_r')

    texture_maxi = np.zeros((st.nclusters, 2), dtype=int)
    # Take the pixel with maximum intensity for contrast signal
    for i in range(st.nclusters):
        # Drop the last (time) index and keep the two spatial coordinates
        coords = np.unravel_index(np.argmax(np.abs(texturestas[i])),
                                  texturestas[i].shape)[:-1]
        texture_maxi[i, :] = coords
        if plot:
            ax = fig_stas.axes[i]
            # Coordinates need to be inverted for display
            ax.plot(*coords[::-1], 'r+', markersize=10, alpha=0.2)
    #%%
    contrast_signals = np.empty((st.nclusters, st.ntotal))
    # Calculate the time course of the center (maximal pixel) of texture STAs
    stas_center = np.zeros((st.nclusters, st.filter_length))
    for i in range(st.nclusters):
        coords = texture_maxi[i, :]
        # Calculate the contrast signal that can be used for GQM
        # Cut the extra part at the beginning that was added by generatecontrast
        contrast_signals[i, :] = contrast[coords[0], coords[1],
                                          st.filter_length-1:]
        stas_center[i] = texturestas[i, coords[0], coords[1], :]

    stas_center_norm = asc.normalize(stas_center)

    # squeeze=False always returns a 2D array of axes, so a 1x1 grid (single
    # cluster) no longer yields a bare Axes object without .ravel()
    fig_contrast, axes = plt.subplots(*plf.numsubplots(st.nclusters),
                                      sharey=True, squeeze=False)
    # zip stops at the shorter input, leaving spare axes of the grid empty
    for ax, timecourse in zip(axes.flat, stas_center_norm):
        ax.plot(timecourse)

    savepath = os.path.join(st.exp_dir, 'data_analysis', st.stimname)
    savefname = f'{st.stimnr}_texturesta'
    if not maxframes:
        maxframes = st.ntotal
    savefname += f'_{maxframes}fr'

    # The y axis is shared, so setting the limits once applies to all axes
    axes[0, 0].set_ylim([np.nanmin(stas_center_norm),
                         np.nanmax(stas_center_norm)])
    fig_contrast.suptitle('Time course of center pixel of texture STAs')
    fig_contrast.savefig(os.path.join(savepath, 'texturestas.svg'))

    # Do not save the contrast signal because it is ~6GB for 20000 frames of recording
    datadict = {
        'texturestas': texturestas,
        'contrast_avg': contrast_avg,
        'stas_center': stas_center,
        'stas_center_norm': stas_center_norm,
        'contrast_signals': contrast_signals,
        'texture_maxi': texture_maxi,
        'maxframes': maxframes,
        'contrast_window': contrast_window,
    }

    np.savez(os.path.join(savepath, savefname), **datadict)
    if plot:
        return fig_stas
Esempio n. 9
0
def multistabrowser(stas, frame_duration=None, cmap=None, centerzero=True):
    """
    Returns an interactive plot to browse multiple spatiotemporal
    STAs at the same time. Requires an interactive matplotlib backend.

    Parameters
    --------
    stas:
        Numpy array containing STAs. First dimension should index individual cells,
        last dimension should index time.
    frame_duration:
      Time between each frame. (optional)
    cmap:
      Colormap to use.
    centerzero:
      Whether to center the colormap around zero for diverging colormaps.

    Returns
    -------
    fig, slider:
        The figure and the slider widget that selects the displayed frame.

    Raises
    ------
    ValueError
        If the active matplotlib backend is not a Qt or Tk backend.

    Example
    ------
    >>> print(stas.shape) # (nrcells, xpixels, ypixels, time)
    (36, 75, 100, 40)
    >>> fig, slider = multistabrowser(stas, frame_duration=1/60)

    Notes
    -----
    When calling the function, the slider is returned to prevent the reference
    to it getting destroyed and to keep it interactive.
    The dummy variable `_` can also be used.
    """
    interactive_backends = ('qt', 'tk')
    backend = mpl.get_backend()
    # Compare case-insensitively; backend names are not consistently
    # capitalised across matplotlib versions.
    if backend[:2].lower() not in interactive_backends:
        raise ValueError('Switch to an interactive backend (e.g. Qt) to see'
                         ' the animation.')

    if isinstance(stas, list):
        stas = np.array(stas)

    if cmap is None:
        cmap = iof.config('colormap')
    if centerzero:
        vmax = absmax(stas)
        vmin = absmin(stas)
    else:
        vmax, vmin = stas.max(), stas.min()

    imshowkwargs = dict(cmap=cmap, vmax=vmax, vmin=vmin)

    ncells = stas.shape[0]
    rows, cols = plf.numsubplots(ncells)
    # squeeze=False keeps a 2D array of axes for single row/column grids too
    fig, axes = plt.subplots(rows, cols, sharex=True, sharey=True,
                             squeeze=False)

    # Start at frame 5, or at the last frame for shorter STAs
    initial_frame = min(5, stas.shape[-1] - 1)

    axsl = fig.add_axes([0.25, 0.05, 0.65, 0.03])
    # For the slider to remain interactive, a reference to it should
    # be kept, so it is returned by the function
    slider_t = Slider(axsl,
                      'Frame before spike',
                      0,
                      stas.shape[-1] - 1,
                      valinit=initial_frame,
                      valstep=1,
                      valfmt='%2.0f')

    # axes.flat runs row by row, so its position equals the cell index
    # (row * cols + col). Spare axes of the grid get no image.
    images = []
    for cell, ax in enumerate(axes.flat):
        if cell < ncells:
            images.append(
                ax.imshow(stas[cell, ..., initial_frame], **imshowkwargs))
        ax.set_axis_off()

    def update(frame):
        # Slider callback: show the selected frame for every cell
        frame = int(frame)
        for cell, im in enumerate(images):
            im.set_data(stas[cell, ..., frame])
        if frame_duration is not None:
            fig.suptitle(f'{frame*frame_duration*1000:4.0f} ms')
        fig.canvas.draw_idle()

    slider_t.on_changed(update)

    plt.tight_layout()
    plt.subplots_adjust(wspace=.01, hspace=.01)
    return fig, slider_t