Пример #1
0
                ax.yaxis.set_ticklabels(['-π', 0, 'π'])
                ax.xaxis.set_ticklabels([])
            elif row % nsubrows == nsubrows - 1:
                im = ax.imshow(cells[mode_i, :],
                               cmap='RdBu_r',
                               vmin=asc.absmin(cells),
                               vmax=asc.absmax(cells),
                               aspect='auto',
                               interpolation='nearest')
                ax.set_xticks(np.array([0, .25, .5, .75, 1]) * cells.shape[-1])
                ax.xaxis.set_ticklabels(
                    np.round((ax.get_xticks() * st.frame_duration), 1))
                ax.set_xlabel('Time [s]')
                if col == 0: ax.set_ylabel('Cells')
            if mode_i == n_components - 1:
                plf.colorbar(im)
    # fig.tight_layout()
    fig.subplots_adjust(wspace=0.3)
    fig.savefig(
        f'/Users/ycan/Downloads/2020-05-27_meeting/cca_DScells_shuffled.pdf')
    plt.show()

    #%%

    # im = plt.imshow(cells, cmap='RdBu_r', vmin=asc.absmin(cells), vmax=asc.absmax(cells))
    # plt.ylabel('Components')
    # plt.xlabel('Cells')
    # plf.colorbar(im)
    # plt.show()

    plt.plot(cca.coef_[:, 0], cca.coef_[:, 1], 'ko')
Пример #2
0
def cca_omb_components(exp: str,
                       stim_nr: int,
                       n_components: int = 6,
                       regularization=None,
                       filter_length=None,
                       cca_solver: str = 'macke',
                       maxframes=None,
                       shufflespikes: bool = False,
                       exclude_allzero_spike_rows: bool = True,
                       savedir: str = None,
                       savefig: bool = True,
                       sort_by_nspikes: bool = True,
                       select_cells: list = None,
                       plot_first_ncells: int = None,
                       whiten: bool = False):
    """
    Analyze OMB responses using canonical correlation analysis and plot the results.

    Four figures are created and closed again (cell weights in default order,
    components together with cell weights, canonical correlations and
    nonlinearities); they are written to `savedir` only if `savefig` is True.
    The components and related arrays are always stored as an .npz file in
    `savedir`. The function returns None.

    Parameters
    ---
    exp:
        Experiment identifier, passed on to OMB.
    stim_nr:
        Stimulus number, passed on to OMB.
    n_components:
        Number of components that will be requested from the CCA analysis. More numbers mean
        the algorithm will stop at a later point. That means components of analyses with fewer
        n_components are going to be identical to the first n components of the higher-number
        component analyses.
    regularization:
        The regularization parameter to be passed onto rcca.CCA. Not relevant for macke.
        None is treated as 0.
    filter_length:
        The length of the time window to be considered in the past for the stimulus and the responses.
        Can be different for stimulus and response, if a tuple (stimulus, response) is given.
        If None, the filter length of the stimulus object is used.
    cca_solver:
        Which CCA solver to use. Options are `rcca` and `macke`(default)
    maxframes: int
        Number of frames to load in the experiment object. Used to avoid memory and performance
        issues.
    shufflespikes: bool
        Whether to randomize the spikes, to validate the results
    exclude_allzero_spike_rows:
        Exclude all cells which have zero spikes for the duration of the stimulus.
    savedir: str
        Custom directory to save the figures and data files. If None, will be saved in the experiment
        directory under the appropriate path. The directory is created if needed.
    savefig: bool
        Whether to save the figures.
    sort_by_nspikes: bool
        Whether to sort the cell weights array by the number of spikes during the stimulus.
    select_cells: list
       A list of indexes for the subset of cells to perform the analysis for.
    plot_first_ncells: int
        Number of cells to plot in the cell plots.
    whiten: bool
        Passed on to the `rcca` solver; not used by the `macke` solver.

    Raises
    ---
    ValueError
        If `cca_solver` is neither `rcca` nor `macke`.
    """
    # Fail before loading any data if the solver name is not recognised.
    if cca_solver not in ('rcca', 'macke'):
        raise ValueError(
            f"Unknown cca_solver {cca_solver!r}; expected 'rcca' or 'macke'.")

    if regularization is None:
        regularization = 0

    st = OMB(exp, stim_nr, maxframes=maxframes)

    if filter_length is None:
        filter_length = st.filter_length

    # A single number means the same window for stimulus and response.
    if isinstance(filter_length, (int, np.integer)):
        filter_length = (int(filter_length), int(filter_length))

    if savedir is None:
        savedir = st.stim_dir / 'CCA'
    else:
        savedir = Path(savedir)
    # The .npz file is always written, so the directory has to exist
    # regardless of where it came from or whether figures are saved.
    savedir.mkdir(exist_ok=True, parents=True)

    spikes = st.allspikes()
    # Count the spikes BEFORE removing the mean; afterwards every row sums
    # to (numerically) zero and the counts would be meaningless.
    nspikes_percell = spikes.sum(axis=1)
    nonzerospikerows = ~np.isclose(nspikes_percell, 0)
    # Set the mean to zero for spikes
    spikes -= spikes.mean(axis=1)[:, None]

    if select_cells is not None:
        select_cells = np.asarray(select_cells)
        spikes = spikes[select_cells]
        # Keep the per-cell bookkeeping aligned with the selected rows.
        nonzerospikerows = nonzerospikerows[select_cells]
        nspikes_percell = nspikes_percell[select_cells]
        st.nclusters = len(select_cells)
        # Convert to a plain list for better string representation:
        # np.array is printed as "array([....])" with newline characters,
        # which is problematic in filenames.
        select_cells = select_cells.tolist()

    # Exclude rows that have no spikes throughout
    if exclude_allzero_spike_rows:
        spikes = spikes[nonzerospikerows, :]
        nspikes_percell = nspikes_percell[nonzerospikerows]

    if shufflespikes:
        spikes = spikeshuffler.shufflebyrow(spikes)

    figsavename = f'{n_components=}_{shufflespikes=}_{select_cells=}_{regularization=}_{filter_length=}_{whiten=}'
    # If the file name gets too long due to the list of selected cells, summarize it.
    if len(figsavename) > 200 and select_cells:
        figsavename = f'{n_components=}_{shufflespikes=}_select_cells={len(select_cells)}cells-index{select_cells[0]}to{select_cells[-1]}_{regularization=}_{filter_length=}_{whiten=}'

    stimulus = mft.packdims(st.bgsteps, filter_length[0])
    spikes = mft.packdims(spikes, filter_length[1])

    if cca_solver == 'rcca':
        resp_comps, stim_comps, cancorrs = cca_rcca(spikes, stimulus,
                                                    filter_length,
                                                    n_components,
                                                    regularization, whiten)
    else:  # 'macke', validated above
        resp_comps, stim_comps, cancorrs = cca_macke(spikes, stimulus,
                                                     filter_length,
                                                     n_components)

    nsp_argsorted = np.argsort(nspikes_percell)
    resp_comps_sorted_nsp = resp_comps[:, nsp_argsorted, :]

    resp_comps_toplot = (resp_comps_sorted_nsp
                         if sort_by_nspikes else resp_comps)
    if plot_first_ncells is not None:
        resp_comps_toplot = resp_comps_toplot[:, :plot_first_ncells, ...]

    # Polar representation of the two stimulus dimensions
    # (presumably x and y motion -- confirm against mft.cart2pol).
    motionfilt_r, motionfilt_theta = mft.cart2pol(stim_comps[:, 0, :],
                                                  stim_comps[:, 1, :])

    # Shared symmetric color limits for all cell-weight images.
    resp_vmin = asc.absmin(resp_comps)
    resp_vmax = asc.absmax(resp_comps)

    # ---- Figure 1: cell weights in default order ----
    nrows, ncols = plf.numsubplots(n_components)
    fig_cells, axes_cells = plt.subplots(nrows, ncols, figsize=(10, 10))

    for i in range(n_components):
        ax = axes_cells.flat[i]
        ax.imshow(resp_comps[i, :],
                  cmap='RdBu_r',
                  vmin=resp_vmin,
                  vmax=resp_vmax,
                  aspect='auto',
                  interpolation='nearest')
        ax.set_title(f'{i}')
    fig_cells.suptitle(f'Cells default order {shufflespikes=}')
    if savefig:
        fig_cells.savefig(savedir / f'{figsavename}_cells_default_order.pdf')
    plt.close(fig_cells)

    # ---- Figure 2: motion filters, magnitude, direction and cell weights ----
    height_list = [1, 1, 1, 3]  # ratios of the plots in each component

    # Time vectors (in ms, ending at zero) for the stimulus and response plots
    t_stim = -np.arange(0, filter_length[0] * st.frame_duration,
                        st.frame_duration)[::-1] * 1000
    t_response = -np.arange(0, filter_length[1] * st.frame_duration,
                            st.frame_duration)[::-1] * 1000
    xtick_loc_params = dict(nbins=4, steps=[2, 5, 10], integer=True)

    nsubrows = len(height_list)
    height_ratios = nrows * height_list
    fig, axes = plt.subplots(nrows=nrows * nsubrows,
                             ncols=ncols,
                             gridspec_kw={'height_ratios': height_ratios},
                             figsize=(11, 10))

    stim_ylim = (stim_comps.min(), stim_comps.max())
    r_ylim = (motionfilt_r.min(), motionfilt_r.max())
    same_filter_lengths = filter_length[0] == filter_length[1]

    # Label of the cell-weight panels; the notes are empty when not applicable.
    sorted_note = '(sorted nsp)' if sort_by_nspikes else ''
    firstn_note = (f'(first {plot_first_ncells} cells)'
                   if type(plot_first_ncells) is int else '')
    cells_ylabel = f'Cells\n{sorted_note}\n{firstn_note}'

    for row, ax_row in enumerate(axes):
        subrow = row % nsubrows
        for col, ax in enumerate(ax_row):
            # Every component occupies `nsubrows` consecutive rows of a column.
            mode_i = (row // nsubrows) * ncols + col
            ax.set_yticks([])
            # Plot motion filters
            if subrow == 0:
                ax.plot(t_stim,
                        stim_comps[mode_i, 0, :],
                        marker='o',
                        markersize=1)
                ax.plot(t_stim,
                        stim_comps[mode_i, 1, :],
                        marker='o',
                        markersize=1)
                if col == 0:
                    ax.set_ylabel('Motion',
                                  rotation=0,
                                  ha='right',
                                  va='center')
                ax.set_ylim(*stim_ylim)

                # Draw a horizontal line for zero and prevent rescaling of x-axis
                xlims = ax.get_xlim()
                ax.hlines(0,
                          *xlims,
                          colors='k',
                          linestyles='dashed',
                          alpha=0.3)
                ax.set_xlim(*xlims)

                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))

                # Time labels on the stimulus axis are only shown for the
                # first component, and only when the stimulus window
                # differs from the response window.
                if mode_i != 0 or same_filter_lengths:
                    ax.xaxis.set_ticklabels([])
                else:
                    ax.tick_params(axis='x', labelsize=8)

            # Plot magnitude of motion
            elif subrow == 1:
                ax.plot(t_stim,
                        motionfilt_r[mode_i, :],
                        color='k',
                        marker='o',
                        markersize=1)
                if col == 0:
                    ax.set_ylabel('Magnitude',
                                  rotation=0,
                                  ha='right',
                                  va='center')
                ax.set_ylim(*r_ylim)
                ax.xaxis.set_ticklabels([])
                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))
            # Plot direction of motion
            elif subrow == 2:
                ax.plot(t_stim,
                        motionfilt_theta[mode_i, :],
                        color='r',
                        marker='o',
                        markersize=1)
                if mode_i == 0:
                    ax.yaxis.set_ticks([-np.pi, 0, np.pi])
                    ax.yaxis.set_ticklabels(['-π', 0, 'π'])
                ax.xaxis.set_ticklabels([])
                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))
            # Plot cell weights
            elif subrow == nsubrows - 1:
                im = ax.imshow(resp_comps_toplot[mode_i, :],
                               cmap='RdBu_r',
                               vmin=resp_vmin,
                               vmax=resp_vmax,
                               aspect='auto',
                               interpolation='nearest',
                               extent=[
                                   t_response[0], t_response[-1], 0,
                                   resp_comps_toplot.shape[1]
                               ])
                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))
                if row == axes.shape[0] - 1:
                    ax.set_xlabel('Time before spike [ms]')
                else:
                    ax.xaxis.set_ticklabels([])

                plf.integerticks(ax, 5, which='y')
                if col == 0:
                    ax.set_ylabel(cells_ylabel,
                                  rotation=0,
                                  ha='right',
                                  va='center')
                else:
                    ax.yaxis.set_ticklabels([])
                if mode_i == n_components - 1:
                    plf.colorbar(im)
            # Add ticks on the right side of the plots
            if col == ncols - 1 and subrow != nsubrows - 1:
                plf.integerticks(ax, 3, which='y')
                ax.yaxis.tick_right()

    fig.suptitle(
        f'CCA components of {st.exp_foldername}\n{shufflespikes=} {n_components=}\n{sort_by_nspikes=}\n'
        + f'{select_cells=} {regularization=} {filter_length=}')
    fig.subplots_adjust(wspace=0.1, hspace=0.3)
    if savefig:
        fig.savefig(savedir / f'{figsavename}_cellsandcomponents.pdf')
    plt.close(fig)

    # ---- Figure 3: canonical correlations ----
    fig_corrs = plt.figure()
    plt.plot(cancorrs, 'ko')
    plt.xlabel('Component index')
    plt.ylabel('Correlation')
    plt.title(f'Cannonical correlations {shufflespikes=}')
    if savefig:
        fig_corrs.savefig(savedir / f'{figsavename}_correlation_coeffs.pdf')
    plt.close(fig_corrs)

    # ---- Figure 4: nonlinearities ----
    fig_nlt, axes_nlt = plt.subplots(nrows, ncols, figsize=(10, 10))

    stim_comps_flatter = stim_comps[:n_components].reshape(
        (n_components, 2 * filter_length[0])).T
    resp_comps_flatter = resp_comps[:n_components].reshape(
        (n_components, resp_comps.shape[1] * filter_length[1])).T

    # Reshape to perform the convolution as a matrix multiplication
    generator_stim = stimulus @ stim_comps_flatter
    generator_resp = spikes @ resp_comps_flatter

    # Only the first n_components axes receive data; the subplot grid may
    # contain more axes than there are components.
    for i, ax in zip(range(n_components), axes_nlt.flat):
        nonlinearity, bins = nlt.calc_nonlin(generator_resp[:, i],
                                             generator_stim[:, i])
        ax.plot(bins, nonlinearity, 'k')
        if i == 0:
            # Allocate once the shape of a single result is known.
            all_nonlinearities = np.empty((n_components, *nonlinearity.shape))
            all_bins = np.empty((n_components, *bins.shape))
        all_nonlinearities[i, ...] = nonlinearity
        all_bins[i, ...] = bins

    # Use common axis limits so that the panels are directly comparable.
    nlt_xlims = []
    nlt_ylims = []
    for ax in axes_nlt.flat:
        nlt_xlims.extend(ax.get_xlim())
        nlt_ylims.extend(ax.get_ylim())
    nlt_maxx, nlt_minx = max(nlt_xlims), min(nlt_xlims)
    nlt_maxy, nlt_miny = max(nlt_ylims), min(nlt_ylims)

    for i, axes_row in enumerate(axes_nlt):
        for j, ax in enumerate(axes_row):
            if i == nrows - 1:
                ax.set_xlabel('Generator (motion)')
            if j == 0:
                ax.set_ylabel('Generator (cells)')
            else:
                ax.yaxis.set_ticklabels([])
            ax.set_xlim([nlt_minx, nlt_maxx])
            ax.set_ylim([nlt_miny, nlt_maxy])

    fig_nlt.suptitle(f'Nonlinearities\n{figsavename}')
    if savefig:
        fig_nlt.savefig(savedir / f'{figsavename}_nonlinearity.png')
    plt.close(fig_nlt)

    # Store the results next to the figures (np.savez appends ".npz").
    datadict = {
        'n_components': n_components,
        'resp_comps': resp_comps,
        'stim_comps': stim_comps,
        'motionfilt_r': motionfilt_r,
        'motionfilt_theta': motionfilt_theta,
        'resp_comps_sorted_nsp': resp_comps_sorted_nsp,
        'select_cells': select_cells,
        'regularization': regularization,
        'filter_length': filter_length,
        'all_nonlinearities': all_nonlinearities,
        'all_bins': all_bins,
        'cca_solver': cca_solver,
    }
    np.savez(savedir / figsavename, **datadict)
Пример #3
0
def randomizestripes(label, exp_name='20180124', stim_nrs=6):
    """
    Plot stripe flicker STAs computed from randomly generated spikes.

    For every cluster the original raster is read, 1000 times as many
    spike times are drawn uniformly between zero and the latest original
    spike, and a spike-triggered average is computed from them against the
    regenerated stimulus. One SVG per cluster is written to
    <exp_dir>/data_analysis/<stimname>/STAs_randomized/<label>/ and, after
    all stimuli are processed, the SVGs of the last plotted stimulus are
    combined into animated_<label>.gif with ImageMagick's `convert`.

    Parameters
    ----------
    label:
        Name of the subfolder for the SVG files; also used in the GIF name.
    exp_name:
        Experiment name, resolved to a directory by iof.exp_dir_fixer.
    stim_nrs:
        A single stimulus number or an iterable of stimulus numbers.

    Raises
    ------
    ValueError
        If the stixel height is less than twice the stixel width (not a
        stripe stimulus) or if Nblinks is neither 1 nor 2.
    """
    # Local imports: only needed for assembling the GIF at the end.
    import glob
    import subprocess

    exp_dir = iof.exp_dir_fixer(exp_name)

    if isinstance(stim_nrs, int):
        stim_nrs = [stim_nrs]

    # Output locations of the most recently plotted stimulus; they stay None
    # if nothing was plotted so that the GIF step can be skipped safely.
    savepath = None
    svgpath = None

    for stim_nr in stim_nrs:
        stimname = iof.getstimname(exp_name, stim_nr)

        clusters, metadata = asc.read_spikesheet(exp_dir)
        parameters = asc.read_parameters(exp_dir, stim_nr)

        scr_width = metadata['screen_width']
        px_size = metadata['pixel_size(um)']

        stx_w = parameters['stixelwidth']
        stx_h = parameters['stixelheight']

        if (stx_h/stx_w) < 2:
            raise ValueError('Make sure the stimulus is stripeflicker.')

        # Number of stripes across the screen
        sy = int(scr_width/stx_w)

        nblinks = parameters['Nblinks']
        try:
            bw = parameters['blackwhite']
        except KeyError:
            bw = False

        try:
            initialseed = parameters['seed']
        except KeyError:
            initialseed = -10000
        # `seed` is advanced by the random number generator below, while
        # `initialseed` is kept for the correction term.
        seed = initialseed

        if nblinks == 1:
            ft_on, ft_off = asc.readframetimes(exp_dir, stim_nr,
                                               returnoffsets=True)
            # Initialize empty array twice the size of one of them, assign
            # value from on or off to every other element.
            frametimings = np.empty(ft_on.shape[0]*2, dtype=float)
            frametimings[::2] = ft_on
            frametimings[1::2] = ft_off
            # Set filter length so that temporal filter is ~600 ms.
            # The unit here is number of frames.
            filter_length = 40
        elif nblinks == 2:
            frametimings = asc.readframetimes(exp_dir, stim_nr)
            filter_length = 20
        else:
            raise ValueError('Unexpected value for nblinks.')

        # Omit everything that happens before the first 10 seconds
        cut_time = 10

        frame_duration = np.average(np.ediff1d(frametimings))
        total_frames = int(frametimings.shape[0]/4)

        nclusters = clusters.shape[0]

        # Binned random spikes and the (still empty) STA of every cluster
        all_spiketimes = []
        stas = []
        for i in range(nclusters):
            spikes_orig = asc.read_raster(exp_dir, stim_nr,
                                          clusters[i, 0], clusters[i, 1])
            spikesneeded = spikes_orig.shape[0]*1000

            # Uniformly distributed spike times up to the last real spike
            spiketimes = np.random.random_sample(spikesneeded)*spikes_orig.max()
            spiketimes = np.sort(spiketimes)
            all_spiketimes.append(asc.binspikes(spiketimes, frametimings))
            stas.append(np.zeros((sy, filter_length)))

        if bw:
            randnrs, seed = randpy.ran1(seed, sy*total_frames)
            randnrs = [1 if i > .5 else -1 for i in randnrs]
        else:
            randnrs, seed = randpy.gasdev(seed, sy*total_frames)

        stimulus = np.reshape(randnrs, (sy, total_frames), order='F')
        del randnrs

        for k in range(filter_length, total_frames-filter_length+1):
            # The time check does not depend on the cluster, so frames
            # before the cut are skipped without touching the stimulus.
            if not frametimings[k] > cut_time:
                continue
            # Stimulus snippet preceding frame k, most recent frame first
            stim_small = stimulus[:, k-filter_length+1:k+1][:, ::-1]
            for j in range(nclusters):
                nspk = all_spiketimes[j][k]
                if nspk != 0:
                    stas[j] += nspk*stim_small

        spikenrs = np.array([spikearr.sum() for spikearr in all_spiketimes])

        for i in range(nclusters):
            stas[i] = stas[i]/spikenrs[i]

        clusterids = plf.clusters_to_ids(clusters)

        correction = corrector(sy, total_frames, filter_length, initialseed)
        # Repeat the per-stripe correction along the time axis
        correction = np.outer(correction, np.ones(filter_length))

        t = np.arange(filter_length)*frame_duration*1000
        vscale = int(stas[0].shape[0] * stx_w*px_size/1000)

        savepath = os.path.join(exp_dir, 'data_analysis',
                                stimname, 'STAs_randomized')
        svgpath = pjoin(savepath, label)
        os.makedirs(svgpath, exist_ok=True)

        # Fixed color limits, identical for all clusters
        vmax = 0.03
        vmin = -vmax
        for i in range(nclusters):
            sta = stas[i]-correction

            plt.figure(figsize=(6, 15))
            ax = plt.subplot(111)
            im = ax.imshow(sta, cmap='RdBu', vmin=vmin, vmax=vmax,
                           extent=[0, t[-1], -vscale, vscale], aspect='auto')
            plt.xlabel('Time [ms]')
            plt.ylabel('Distance [mm]')

            plf.spineless(ax)
            plf.colorbar(im, ticks=[vmin, 0, vmax], format='%.2f', size='2%')
            plt.suptitle('{}\n{}\n'
                         '{} Rating: {}\n'
                         'nrofspikes {:5.0f}'.format(exp_name,
                                                       stimname,
                                                       clusterids[i],
                                                       clusters[i][2],
                                                       spikenrs[i]))
            plt.subplots_adjust(top=.90)
            plt.savefig(os.path.join(svgpath, clusterids[i]+'.svg'),
                        bbox_inches='tight')
            plt.close()

    # Nothing was processed (e.g. empty stim_nrs): there is nothing to animate.
    if svgpath is None:
        return

    frames = sorted(glob.glob(os.path.join(svgpath, '*svg')))
    if not frames:
        return

    gifpath = os.path.join(savepath, f'animated_{label}.gif')
    # Pass the arguments as a list so that paths or labels containing
    # spaces or shell metacharacters cannot break or alter the command.
    try:
        subprocess.run(['convert', '-delay', '25', *frames, gifpath],
                       check=False)
    except FileNotFoundError:
        # The GIF is a convenience; a missing ImageMagick must not discard
        # the figures that were already written.
        print(f'ImageMagick "convert" not found; skipped {gifpath}')
Пример #4
0
    'vmin': 0,
    'vmax': allspikes.max(),
    #                'cmap':'Greys_r',
    'cmap': 'magma',
}
plt.figure()
ax_rsp = plt.subplot(311)
ax_rsp.matshow(allspikes[:, sl], **imshowkwargs)

ax_psp = plt.subplot(312, sharex=ax_rsp, sharey=ax_rsp)
im = ax_psp.matshow(predspikes[:, sl], **imshowkwargs)
im.cmap.set_over('r')
im.cmap.set_under('k')
ax_stim = plt.subplot(313, sharex=ax_rsp)
ax_stim.plot(stimulus[sl], lw=.8)
plf.colorbar(im, ax=ax_stim, size='1%')
plt.tight_layout()
plt.show()

#%%
avgspikes = allspikes.mean(axis=1)
avgspikes_pred = predspikes.mean(axis=1)

plt.figure(figsize=(8.5, 5.5))
ax1 = plt.subplot(121)
#ax1 = plt.gca()
ax1.scatter(avgspikes, avgspikes_pred)
ax1.set_xlabel('Avg spike nr per time bin')
ax1.set_ylabel('Predicted spike nr per time bin')
ax1.set_xlim([-.05, .9])
ax1.set_ylim([-.05, .9])
Пример #5
0
        vmin = -vmax
        ax = plt.subplot(2, 2, j+1)
        plt.title('STA quality: {:4.2f}'.format(quals[i]))

        ax.set_aspect('equal')

        im = ax.imshow(sta, cmap='RdBu', vmin=vmin, vmax=vmax,
                       extent=[0, t[-1], -vscale, vscale], aspect='auto')
        if j >= 2:
            plt.xlabel('Time [s]\n\nFrame duration: {:2.1f}'
                       'ms'.format(frame_duration*1000))
        else:
            ax.axes.get_xaxis().set_visible(False)
        if j == 0:
            plt.ylabel('B&W\n\nDistance[µm]')
        if j == 2:
            plt.ylabel('Gaussian\n\nDistance[µm]')
        if j % 2 == 1:
            ax.axes.get_yaxis().set_visible(False)

        plf.spineless(ax)
        plf.colorbar(im, ticks=[vmin, 0, vmax], format='%.2f', size='2%')
    plt.subplots_adjust(hspace=.4, wspace=.5)
    plt.suptitle('{}\n{}'.format(exp_name, clusterids[i]))

    savepath = os.path.join(exp_dir, 'data_analysis', 'allstripes')
    if not os.path.isdir(savepath):
        os.makedirs(savepath, exist_ok=True)
    plt.savefig(os.path.join(savepath, clusterids[i]+'.svg'))
    plt.close()
Пример #6
0
def plotstripestas(exp_name, stim_nrs):
    """
    Plot and save all the STAs from multiple stripe flicker stimuli.

    For every stimulus the stored analysis is loaded with iof.load and one
    SVG per cluster is written to
    <exp_dir>/data_analysis/<stimname>/STAs/<clusterid>.svg.

    Parameters
    ----------
    exp_name:
        Experiment name, resolved to a directory by iof.exp_dir_fixer and
        used to load the data of every stimulus.
    stim_nrs:
        A single stimulus number or a sequence of stimulus numbers. An
        empty sequence makes the function return without plotting.
    """
    exp_dir = iof.exp_dir_fixer(exp_name)

    _, metadata = asc.read_spikesheet(exp_dir)
    px_size = metadata['pixel_size(um)']

    if isinstance(stim_nrs, int):
        stim_nrs = [stim_nrs]
    elif len(stim_nrs) == 0:
        return

    for stim_nr in stim_nrs:
        data = iof.load(exp_name, stim_nr)

        clusters = data['clusters']
        stas = data['stas']
        filter_length = data['filter_length']
        stx_w = data['stx_w']
        # Kept in its own variable: overwriting the `exp_name` argument
        # here would change what the next iteration passes to iof.load.
        stored_exp_name = data['exp_name']
        stimname = data['stimname']
        frame_duration = data['frame_duration']
        quals = data['quals']

        clusterids = plf.clusters_to_ids(clusters)

        # Time axis in ms and vertical extent of the image
        # (number of stripes * stripe width * pixel size, scaled by 1/1000).
        t = np.arange(filter_length) * frame_duration * 1000
        vscale = int(stas[0].shape[0] * stx_w * px_size / 1000)

        savepath = os.path.join(exp_dir, 'data_analysis', stimname, 'STAs')
        os.makedirs(savepath, exist_ok=True)

        for i in range(clusters.shape[0]):
            sta = stas[i]

            # Symmetric color limits around zero, individually per cluster
            vmax = np.max(np.abs(sta))
            vmin = -vmax
            plt.figure(figsize=(6, 15))
            ax = plt.subplot(111)
            im = ax.imshow(sta,
                           cmap='RdBu',
                           vmin=vmin,
                           vmax=vmax,
                           extent=[0, t[-1], -vscale, vscale],
                           aspect='auto')
            plt.xlabel('Time [ms]')
            plt.ylabel('Distance [mm]')

            plf.spineless(ax)
            plf.colorbar(im, ticks=[vmin, 0, vmax], format='%.2f', size='2%')
            plt.suptitle(f'{stored_exp_name}\n{stimname}\n'
                         f'{clusterids[i]} Rating: {clusters[i][2]}\n'
                         f'STA quality: {quals[i]:4.2f}')
            plt.subplots_adjust(top=.90)
            plt.savefig(os.path.join(savepath, clusterids[i] + '.svg'),
                        bbox_inches='tight')
            plt.close()
        print(f'Plotting of {stimname} completed.')
Пример #7
0
def plotcheckersvd(expname, stimnr, filename=None):
    """
    Plot the first two components of SVD analysis.

    For every cluster the STA is cut around its maximum, decomposed with
    msc.svd, and a figure with the frame at the peak time, the first two
    spatial components and the corresponding temporal traces is saved as
    <exp_dir>/data_analysis/<stimname>/<savefolder>/<clusterid>.svg.
    Clusters for which the cut or the SVD fails are skipped.

    Parameters
    ----------
    expname:
        Experiment name, resolved to a directory by iof.exp_dir_fixer.
    stimnr:
        Stimulus number, passed on to iof.load.
    filename:
        Optional name of the data file to load instead of the default.
        If given, the output folder is 'SVD_<filename without .npz>',
        otherwise 'SVD'.
    """
    if filename:
        filename = str(filename)
        # Remove only a trailing '.npz'. str.strip('.npz') would remove any
        # of the characters '.', 'n', 'p', 'z' from BOTH ends of the name.
        if filename.endswith('.npz'):
            label = filename[:-len('.npz')]
        else:
            label = filename
        savefolder = 'SVD_' + label
    else:
        savefolder = 'SVD'

    exp_dir = iof.exp_dir_fixer(expname)
    _, metadata = asc.read_spikesheet(exp_dir)
    px_size = metadata['pixel_size(um)']

    data = iof.load(expname, stimnr, filename)

    stas = data['stas']
    max_inds = data['max_inds']
    clusters = data['clusters']
    stx_h = data['stx_h']
    frame_duration = data['frame_duration']
    stimname = data['stimname']
    exp_name = data['exp_name']

    clusterids = plf.clusters_to_ids(clusters)

    # Determine frame size so that the total frame covers
    # an area large enough i.e. 2*700um
    f_size = int(700 / (stx_h * px_size))

    plotpath = os.path.join(exp_dir, 'data_analysis', stimname, savefolder)

    rows = 2
    cols = 3
    titles = ('center px', 'SVD spatial 1', 'SVD spatial 2')

    for i in range(clusters.shape[0]):
        sta = stas[i]
        max_i = max_inds[i]

        try:
            sta, max_i = msc.cut_around_center(sta, max_i, f_size=f_size)
        except ValueError:
            continue
        # Spatial frame at the time index of the maximum
        fit_frame = sta[:, :, max_i[2]]

        try:
            sp1, sp2, t1, t2, _, _ = msc.svd(sta)
        # If the STA is noisy (msc.cut_around_center produces an empty array)
        # SVD cannot be calculated, in this case we skip that cluster.
        except np.linalg.LinAlgError:
            continue

        plt.figure(dpi=200)
        plt.suptitle(f'{exp_name}\n{stimname}\n{clusterids[i]}')

        # Color limits are shared by all three images and are taken from
        # the two spatial components only.
        vmax = np.max(np.abs([sp1, sp2]))
        vmin = -vmax

        for j, (image, title) in enumerate(zip((fit_frame, sp1, sp2),
                                               titles)):
            ax = plt.subplot(rows, cols, j + 1)
            im = plt.imshow(image,
                            vmin=vmin,
                            vmax=vmax,
                            cmap=iof.config('colormap'))
            ax.set_aspect('equal')
            plt.xticks([])
            plt.yticks([])
            # Color the frame of each image like the matching line in the
            # temporal plot below (C0, C1, C2).
            for child in ax.get_children():
                if isinstance(child, matplotlib.spines.Spine):
                    child.set_color('C{}'.format(j % 3))
                    child.set_linewidth(2)
            plt.title(title)
            if j == 2:
                # Colorbar and scale bar go on the last image only.
                plf.colorbar(im, ticks=[vmin, 0, vmax], format='%.2f')
                barsize = 100 / (stx_h * px_size)
                scalebar = AnchoredSizeBar(ax.transData,
                                           barsize,
                                           '100 µm',
                                           'lower left',
                                           pad=0,
                                           color='k',
                                           frameon=False,
                                           size_vertical=.3)
                ax.add_artist(scalebar)

        t = np.arange(sta.shape[-1]) * frame_duration * 1000
        plt.subplots_adjust(wspace=0.3, hspace=0)
        ax = plt.subplot(rows, 1, 2)
        plt.plot(t, sta[max_i[0], max_i[1], :], label='center px')
        plt.plot(t, t1, label='Temporal 1')
        plt.plot(t, t2, label='Temporal 2')
        plt.xlabel('Time[ms]')
        plf.spineless(ax, 'trlb')  # Turn off spines using custom function

        os.makedirs(plotpath, exist_ok=True)
        plt.savefig(os.path.join(plotpath, clusterids[i] + '.svg'), dpi=300)
        plt.close()
    print(f'Plotted checkerflicker SVD for {stimname}')
Пример #8
0
def plot_checker_stas(exp_name, stim_nr, filename=None):
    """
    Plot and save all STAs from checkerflicker analysis. The plots
    will be saved in a new folder called STAs under the data analysis
    path of the stimulus.

    The data is loaded with iof.load, which uses its default file for the
    stimulus unless `filename` is supplied. (An earlier note named
    <exp_dir>/data_analysis/<stim_nr>_*/<stim_nr>_data.h5 as the default;
    the actual format is decided by iof.load -- TODO confirm.)

    Parameters
    ----------
    exp_name:
        Experiment name, resolved to a directory by iof.exp_dir_fixer.
    stim_nr:
        Stimulus number; converted to a string before loading.
    filename:
        Optional name of the data file to load. If given, the output
        folder is 'STAs_<filename without .npz>' and the label is appended
        to the stimulus name in the figure title.
    """

    from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar

    exp_dir = iof.exp_dir_fixer(exp_name)
    stim_nr = str(stim_nr)

    _, metadata = asc.read_spikesheet(exp_dir)
    px_size = metadata['pixel_size(um)']

    if filename:
        filename = str(filename)
        # Remove only a trailing '.npz'. str.strip('.npz') would remove any
        # of the characters '.', 'n', 'p', 'z' from BOTH ends of the name.
        if filename.endswith('.npz'):
            label = filename[:-len('.npz')]
        else:
            label = filename
        savefolder = 'STAs_' + label
    else:
        savefolder = 'STAs'
        label = ''

    data = iof.load(exp_name, stim_nr, fname=filename)

    clusters = data['clusters']
    stas = data['stas']
    filter_length = data['filter_length']
    stx_h = data['stx_h']
    exp_name = data['exp_name']
    stimname = data['stimname']

    # One subplot per time bin of the filter
    subplot_arr = plf.numsubplots(filter_length)

    savedir = os.path.join(exp_dir, 'data_analysis', stimname, savefolder)
    os.makedirs(savedir, exist_ok=True)

    for j in range(clusters.shape[0]):
        a = stas[j]
        # Zero-padded channel and cluster number, e.g. '00301'
        clusterid = '{:0>3}{:0>2}'.format(clusters[j][0], clusters[j][1])

        # Symmetric color limits shared by all frames of this STA
        sta_max = np.max(np.abs([np.max(a), np.min(a)]))
        sta_min = -sta_max
        plt.figure(dpi=250)
        for i in range(filter_length):
            ax = plt.subplot(subplot_arr[0], subplot_arr[1], i + 1)
            im = ax.imshow(a[:, :, i],
                           vmin=sta_min,
                           vmax=sta_max,
                           cmap=iof.config('colormap'))
            ax.set_aspect('equal')
            plt.axis('off')
            if i == 0:
                # Scale bar spanning 10 stixels on the first frame
                scalebar = AnchoredSizeBar(ax.transData,
                                           10,
                                           '{} µm'.format(10 * stx_h *
                                                          px_size),
                                           'lower left',
                                           pad=0,
                                           color='k',
                                           frameon=False,
                                           size_vertical=1)
                ax.add_artist(scalebar)
            if i == filter_length - 1:
                plf.colorbar(im, ticks=[sta_min, 0, sta_max], format='%.2f')
        plt.suptitle('{}\n{}\n'
                     '{} Rating: {}'.format(exp_name,
                                            stimname + label,
                                            clusterid,
                                            clusters[j][2]))

        plt.savefig(os.path.join(savedir, clusterid) + '.png',
                    bbox_inches='tight')
        plt.close()
    print(f'Plotted checkerflicker STA for {stimname}')