Example #1
0
def chkmax2ombcoord(cell_i, exp, omb_stimnr, chk_stimnr):
    """
    Map a cell's checkerflicker STA maximum into OMB stimulus coordinates.

    Parameters
    ----------
    cell_i : int
        Index of the cell in the checkerflicker data.
    exp : str
        Experiment name, passed to ``iof.load`` and ``coord_chk2omb``.
    omb_stimnr : int
        Stimulus number of the OMB stimulus.
    chk_stimnr : int
        Stimulus number of the checkerflicker stimulus.

    Returns
    -------
    The coordinates computed by ``coord_chk2omb`` for the maximum pixel.
    """
    max_inds = np.array(iof.load(exp, chk_stimnr)['max_inds'])
    # Drop the last component of the maximum index before converting;
    # presumably the time index, since checkerflicker STAs are indexed
    # (y, x, t) elsewhere in this file.
    spatial_max = max_inds[cell_i][:-1]
    return coord_chk2omb(spatial_max, exp, omb_stimnr, chk_stimnr)
Example #2
0
# Center/surround boundaries; presumably in units of sigma, like the
# inner_b/outer_b parameters of plotcheckersurround.
inner_b=2
outer_b=4

exp_name = '20180207'
stim_nr = 11


exp_dir = iof.exp_dir_fixer(exp_name)
stim_nr = str(stim_nr)

savefolder = 'surroundplots'

# Pixel size from the experiment's spike sheet metadata.
_, metadata = asc.read_spikesheet(exp_name)
px_size = metadata['pixel_size(um)']

data = iof.load(exp_name, stim_nr)

clusters = data['clusters']
stas = data['stas']
stx_h = data['stx_h']
# Replace the name given above with the one stored in the data file.
exp_name = data['exp_name']
stimname = data['stimname']
max_inds = data['max_inds']
frame_duration = data['frame_duration']
filter_length = data['filter_length']
# Only the last row of the quality array is used.
quals = data['quals'][-1, :]

# Sum of each entry of all_spiketimes; presumably binned spike trains, so
# this is the spike count per cluster — confirm.
spikenrs = np.array([a.sum() for a in data['all_spiketimes']])

# Restrict the analysis to a single hand-picked cluster (index 33).
choose = [33]
clusters = clusters[choose]
Example #3
0
@author: ycan
"""

import iofuncs as iof
import matplotlib.pyplot as plt
import plotfuncs as plf

exps = [('V', 10), ('Kara', 5), ('20171116', 6), ('20171122', 6),
        ('20171122', 7)]

# Last row of the STA quality array of every (experiment, stimulus) pair,
# collected in the same order as `exps`.
all_quals = []
for i, exp in enumerate(exps):
    data = iof.load(*exp)
    quals = data['quals'][-1, :]
    all_quals.append(quals)

# One horizontal row of points per experiment, labelled at x=50 with the
# (experiment, stimulus) tuple it came from.
ax = plt.subplot(111)
for j, exp_quals in enumerate(all_quals):
    row = [j] * len(exp_quals)
    plt.scatter(exp_quals, row)
    plt.text(50, j, str(exps[j]), fontsize=8)
ax.set_yticks([])
plt.ylabel('Experiment')
plt.xlabel('Center px z-score')
plt.title('Distribution of STA qualities')
plf.spineless(ax)
plt.show()
# NOTE(review): `data` is whatever was last bound to that name earlier in
# the file; it must provide 'colorcategories' and 'colorlabels' — confirm.
colorcategories = data['colorcategories']
colorlabels = data['colorlabels']

# [experiment, cluster id] pairs to plot.
toplot = [
    ['20180124', '02001'],  # Increasing bias
    ['20180207', '03503'],  # Decreasing bias
]

for j, (exp_name, clustertoplot) in enumerate(toplot):
    # Hard-coded onoffsteps stimulus numbers of the two light levels.
    # NOTE(review): `onoffs` is not set for any other experiment name.
    if '20180124' in exp_name or '20180207' in exp_name:
        onoffs = [3, 8]
    elif '20180118' in exp_name:
        onoffs = [3, 10]

    for i, (cond, stim) in enumerate(zip(['M', 'P'], onoffs)):
        expdata = iof.load(exp_name, stim)
        clusters = expdata['clusters']
        preframedur = expdata['preframe_duration']
        stimdur = expdata['stim_duration']
        clusterids = plf.clusters_to_ids(clusters)
        # Position of the requested cluster in this stimulus' data. The
        # comprehension's own `i` is local to it (Python 3) and does not
        # overwrite the loop counter.
        index = [i for i, cl in enumerate(clusterids)
                 if cl == clustertoplot][0]

        fr = expdata['all_frs'][index]
        t = expdata['t']
        baselines = expdata['baselines'][index]

        # Left column of a 4x2 grid: one row per (experiment, condition)
        # in loop order.
        plotind = [1, 3, 5, 7][i + 2 * j]
        ax = plt.subplot(4, 2, plotind)
        ax.plot(t, fr, 'k', linewidth=.5)
        plf.spineless(ax)
Example #5
0
def stripestim(exp_name):
    """
    Return the stripe flicker stimulus numbers of a light-level experiment.

    The stimulus numbers differ between experiments and are hard coded for
    the three supported experiments.

    Parameters
    ----------
    exp_name : str
        Experiment name; matched by substring, so a longer name or path
        containing the date also works.

    Returns
    -------
    list of int
        Two stimulus numbers. Callers in this file load the first as the
        'm' condition and the second as the 'p' condition.

    Raises
    ------
    ValueError
        If ``exp_name`` matches none of the supported experiments (this
        previously surfaced as an UnboundLocalError).
    """
    if '20180124' in exp_name or '20180207' in exp_name:
        return [6, 12]
    if '20180118' in exp_name:
        return [7, 14]
    raise ValueError(f'No stripe flicker stimuli known for {exp_name!r}; '
                     'expected 20180118, 20180124 or 20180207.')


exps = ['20180118', '20180124', '20180207']

data = np.load('/home/ycan/Documents/thesis/analysis_auxillary_files/'
               'thesis_csiplotting.npz')
include = data['include']
cells = data['cells']
groups = data['groups']

all_fits = np.empty((*cells.shape, 73))

for exp in exps:
    stim = stripestim(exp)
    fits_m = np.array(iof.load(exp, stim[0])['fits'])
    fits_p = np.array(iof.load(exp, stim[1])['fits'])

# NOTE(review): fits_m / fits_p are overwritten on every iteration above, so
# only the last experiment in `exps` is plotted below, and all_fits is never
# filled. Confirm whether that is intended.
# `nrcells` was used without ever being defined (NameError). Take it from
# the loaded fits, which hold one row per cell (indexed fits[i, :] below).
nrcells = min(fits_m.shape[0], fits_p.shape[0])
p = plf.numsubplots(nrcells)
axes = plt.subplots(*p)[1].ravel()
for i in range(nrcells):
    ax = axes[i]
    ax.plot(fits_m[i, :])
    ax.plot(fits_p[i, :])
    plf.spineless(ax)
    ax.set_axis_off()
Example #6
0
    return res


#%%
# If the script is being imported from elsewhere to use the functions, do not run the simulation
if __name__ == '__main__':

    # Using real filter and stimulus. `if True` is a manual switch: change
    # it to False to use the hand-set parameters of the else branch.
    if True:
        import genlinmod as glm
        import iofuncs as iof
        import analysis_scripts as asc

        # (experiment name, stimulus number)
        expstim = ('20180802', 1)

        data = iof.load(*expstim)
        _, metadata = asc.read_spikesheet(expstim[0])
        # Use the STA of the cluster at index 3 as the filter.
        sta = data['stas'][3]

        filter_length = sta.shape[0]
        frame_rate = metadata['refresh_rate']
        # Duration of one frame; presumably seconds if refresh_rate is in
        # Hz — confirm.
        time_res = 1 / frame_rate
        tstop = data['total_frames'] * time_res
        # Roughly one time point per stimulus frame (float arange, so the
        # length may differ by one through rounding).
        t = np.arange(0, tstop, time_res)
        k_in = sta

        stim = glm.loadstim(*expstim)

    else:
        # NOTE(review): this branch appears truncated in this view; only
        # filter_length and frame_rate are set here.
        filter_length = 40
        frame_rate = 60
Example #7
0
#def conv(k, x):
#    return np.convolve(k, x, 'full')[:-k.shape[0]+1]


def normalizestas(stas):
    """
    Scale each STA so that its largest absolute value along axis 1 is 1.

    Parameters
    ----------
    stas : array_like
        Stack of STAs with the STA index on axis 0 and at least one more
        axis. Maxima are taken along axis 1.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``stas``. Slices whose maximum is zero
        give NaN (0/0), as in plain NumPy division.
    """
    stas = np.asarray(stas)
    # keepdims=True keeps the reduced axis, so the maxima broadcast back onto
    # the right elements for input of any rank. The previous
    # repeat().reshape() trick only lined up correctly for 2-D input.
    b = np.abs(stas).max(axis=1, keepdims=True)
    return stas / b


#%%
exp_name = '20180802'
stim_nr = 1
data = iof.load(exp_name, stim_nr)
stimulus = glm.loadstim(exp_name, stim_nr)
clusters = data['clusters']
#%%
# Disabled: inline version of the normalization done by normalizestas().
#stas = np.array(data['stas'])
#stas_normalized = np.abs(stas).max(axis=1)
#stas_normalized = a / stas_normalized.repeat(stas.shape[1]).reshape(stas.shape)
# Second element returned by asc.ft_nblinks (presumably the frame times).
frametimes = asc.ft_nblinks(exp_name, stim_nr)[1]

# STAs are used without normalization (normalizestas call disabled).
#stas = normalizestas(data['stas'])
stas = np.array(data['stas'])

# Preallocated outputs, filled outside this view: one array shaped like the
# STAs and one value per STA.
predstas = np.zeros(stas.shape)
predmus = np.zeros(stas.shape[0])
# Wall-clock start, presumably for timing the code that follows.
start = dt.datetime.now()
Example #8
0
def allonoff(exp_name, stim_nrs):
    """
    Compare the responses of every cell across several onoffsteps stimuli.

    For each cluster, one figure overlays the firing rate recorded during
    each stimulus, labelled with the stimulus name and the cell's On-Off
    bias. It is saved as ``<exp_dir>/data_analysis/allonoff/<clusterid>.svg``.
    A second figure with one On-Off index histogram per stimulus is saved as
    ``<exp_dir>/data_analysis/onoffindex_dist.svg``.

    Parameters
    ----------
    exp_name : str
        Experiment name or path, resolved with ``iof.exp_dir_fixer``.
    stim_nrs : int or sequence of int
        Stimulus numbers of the onoffsteps stimuli. At least two are
        required; otherwise a message is printed and nothing is done.
    """
    if isinstance(stim_nrs, int) or len(stim_nrs) <= 1:
        print('Multiple onoffsteps stimuli expected, '
              'allonoff analysis will be skipped.')
        return

    exp_dir = iof.exp_dir_fixer(exp_name)
    exp_name = os.path.split(exp_dir)[-1]

    for j, stim in enumerate(stim_nrs):
        data = iof.load(exp_name, stim)
        all_frs = data['all_frs']
        clusters = data['clusters']
        preframe_duration = data['preframe_duration']
        stim_duration = data['stim_duration']
        onoffbias = data['onoffbias']
        t = data['t']

        if j == 0:
            # Sized from the first stimulus; all stimuli are assumed to share
            # the same clusters and time axis.
            a = np.zeros((clusters.shape[0], t.shape[0], len(stim_nrs)))
            bias = np.zeros((clusters.shape[0], len(stim_nrs)))
        a[:, :, j] = np.array(all_frs)
        bias[:, j] = onoffbias

    # Stimulus names do not depend on the cell. Look them up once instead of
    # once per cell and stimulus, plus again for the histogram labels.
    stimlabels = [iof.getstimname(exp_name, stim).replace('onoffsteps_', '')
                  for stim in stim_nrs]

    plotpath = os.path.join(exp_dir, 'data_analysis', 'allonoff')
    clusterids = plf.clusters_to_ids(clusters)
    os.makedirs(plotpath, exist_ok=True)

    for i in range(clusters.shape[0]):
        ax = plt.subplot(111)
        for j, stimlabel in enumerate(stimlabels):
            labeltxt = stimlabel + f' Bias: {bias[i, j]:4.2f}'
            plt.plot(t, a[i, :, j], alpha=.5, label=labeltxt)
        plt.title(f'{exp_name}\n{clusterids[i]}')
        plt.legend()
        plf.spineless(ax)
        plf.drawonoff(ax, preframe_duration, stim_duration, h=.1)

        plt.savefig(os.path.join(plotpath, clusterids[i]) + '.svg',
                    format='svg',
                    dpi=300)
        plt.close()

    rows = len(stim_nrs)
    columns = 1
    _, axes = plt.subplots(rows, columns, sharex=True)
    colors = plt.get_cmap('tab10')

    for i, stimlabel in enumerate(stimlabels):
        ax = axes[i]
        # Cells without a defined bias (NaN) trigger RuntimeWarnings in hist.
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', category=RuntimeWarning)
            ax.hist(bias[:, i],
                    bins=20,
                    color=colors(i),
                    range=[-1, 1],
                    alpha=.5)

        ax.set_ylabel(stimlabel)
        plf.spineless(ax)
    plt.suptitle(f'Distribution of On-Off Indices for {exp_name}')
    plt.subplots_adjust(top=.95)
    plt.xlabel('On-Off index')
    plt.savefig(os.path.join(exp_dir, 'data_analysis', 'onoffindex_dist.svg'),
                format='svg',
                dpi=300)
    plt.close()
Example #9
0
"""
Created on Tue Feb 20 16:00:51 2018

@author: ycan


This script was used to play around and optimize stripeflicker_SVD
"""
import matplotlib.pyplot as plt
from matplotlib import transforms
import numpy as np
import analysis_scripts as asc
import iofuncs as iof
import plotfuncs as plf

data = iof.load('20180207', 12)
stas = data['stas']
clusters = data['clusters']
clusterids = plf.clusters_to_ids(clusters)

# Restrict to a single cell (index 4). Both are kept as one-element lists,
# presumably so code written for lists of cells still works.
choose = 4
clusterids = [clusterids[choose]]
stas = [stas[choose]]


def component(i, u, s, v):
    """
    Return the i-th rank-one term ``s[i] * outer(u[:, i], v[i, :])``.

    ``u`` holds left singular vectors as columns, ``s`` the singular values
    and ``v`` right singular vectors as rows (presumably the output of
    ``np.linalg.svd``). Summing the terms over all ``i`` then reconstructs
    the decomposed matrix.
    """
    left = u[:, i]
    right = v[i, :]
    weight = s[i]
    return np.outer(left, right) * weight

Example #10
0
# NOTE(review): nbases and make_filter are defined outside this view.
# Test weights; only used by the commented-out make_filter call below.
w = np.zeros(nbases)
w[-5] = -.3
w[-4] = .5

#filt = make_filter(t, w)

#plt.plot(t, filt)
#plt.show()

#%%
import iofuncs as iof
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

data = iof.load('20180710', 1)
stas = data['stas']
bars = np.arange(nbases)
barkw = {'width': 0.2}

for i, sta in enumerate(stas):

    # Sample times 0, frame_duration, ... up to (not including)
    # frame_duration * filter_length. Float arange, so the length may be
    # off by one through rounding — confirm it matches len(sta).
    timev = np.arange(0, data['frame_duration'] * data['filter_length'],
                      data['frame_duration'])
    w0 = np.zeros(nbases)

    # Fit the basis weights of make_filter to this STA, starting from zero
    # with every weight bounded to [-2, 2].
    popt, pcov = curve_fit(make_filter, timev, sta, p0=w0, bounds=[-2, 2])

    fig = plt.figure(figsize=(8, 3.5))
    ax1 = plt.subplot(1, 2, 1)
    ax2 = plt.subplot(1, 2, 2)
@author: ycan
"""
import warnings
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np

import gaussfitter as gfit
import iofuncs as iof
import analysis_scripts as asc
import miscfuncs as msc

exp = '20180710'
sorted_stimuli = asc.stimulisorter(exp)
# First stimulus classified as 'frozennoise' in this experiment.
checker = sorted_stimuli['frozennoise'][0]
data = iof.load(exp, checker)
parameters = asc.read_parameters(exp, checker)

stas = data['stas']
max_inds = data['max_inds']

# Pick a single cell (index 0).
i = 0
sta = stas[i]
max_i = max_inds[i]
# NOTE(review): the meaning of `bound` is not visible in this view.
bound = 1.5

#%%
def fitgaussian(sta, f_size=10):
    max_i = np.unravel_index(np.argmax(np.abs(sta)), sta.shape)
    try:
        sta, max_i_cut = msc.cut_around_center(sta, max_i, f_size)
def stripesurround(exp_name, stimnrs):
    """
    Fit a one-dimensional center-surround model to stripe flicker STAs.

    For every cluster, the STA is cut around its maximum, and the spatial
    profile at the time index of the maximum is fitted with
    ``centersurround_onedim``. If the bounded fit rejects the initial guess
    ("`x0` is infeasible."), a single ``onedgauss`` fit is used instead.
    The STA, data and fit are plotted and saved to
    ``<exp_dir>/data_analysis/<stimname>/stripesurrounds/<clusterid>.svg``.

    The center-surround indices (``popt[0] / popt[3]``, or 0 without a
    surround) and the polarities are added to the stimulus data. A polarity
    is +1 when the largest absolute value of the profile is positive and -1
    otherwise. The data is saved as
    ``<exp_dir>/data_analysis/<stimname>/<stimnr>_data.npz``.

    Parameters
    ----------
    exp_name : str
        Experiment name or path, resolved with ``iof.exp_dir_fixer``.
    stimnrs : int or sequence of int
        Stripe flicker stimulus number(s) to analyse.
    """
    exp_dir = iof.exp_dir_fixer(exp_name)

    if isinstance(stimnrs, int):
        stimnrs = [stimnrs]

    # The pixel size belongs to the experiment, not to a stimulus, so read
    # the spike sheet once instead of once per stimulus.
    _, metadata = asc.read_spikesheet(exp_dir)
    px_size = metadata['pixel_size(um)']

    for stimnr in stimnrs:
        data = iof.load(exp_name, stimnr)

        clusters = data['clusters']
        stas = data['stas']
        max_inds = data['max_inds']
        stx_w = data['stx_w']
        exp_name = data['exp_name']
        stimname = data['stimname']

        clusterids = plf.clusters_to_ids(clusters)

        # Number of stripes spanning 700 um (rounded down); the STA is cut
        # to twice this many stripes around its maximum.
        fsize = int(700 / (stx_w * px_size))

        cs_inds = np.empty(clusters.shape[0])
        polarities = np.empty(clusters.shape[0])

        savepath = os.path.join(exp_dir, 'data_analysis', stimname)
        # Create the output folder once per stimulus instead of per cell.
        os.makedirs(os.path.join(savepath, 'stripesurrounds'), exist_ok=True)

        for i in range(clusters.shape[0]):
            sta = stas[i]
            max_i = max_inds[i]

            sta, max_i = msc.cutstripe(sta, max_i, fsize * 2)
            plt.figure(figsize=(12, 10))
            ax = plt.subplot(121)
            plf.stashow(sta, ax)

            # Isolate the time point from which the fit will
            # be obtained
            fitv = sta[:, max_i[1]]
            # Make a space vector
            s = np.arange(fitv.shape[0])

            if np.max(fitv) != np.max(np.abs(fitv)):
                onoroff = -1
            else:
                onoroff = 1
            polarities[i] = onoroff
            # Determine the peak values for center and surround
            # to give as initial parameters for curve fitting
            centerpeak = -onoroff * np.max(fitv * onoroff)
            surroundpeak = -onoroff * np.max(fitv * -onoroff)

            # Define initial guesses for the center and surround gaussians
            # First set of values are for center, second for surround.
            p_initial = [centerpeak, max_i[0], 2, surroundpeak, max_i[0], 4]
            bounds = ([0, -np.inf, -np.inf, 0, -np.inf, -np.inf], np.inf)

            try:
                popt, _ = curve_fit(centersurround_onedim,
                                    s,
                                    fitv,
                                    p0=p_initial,
                                    bounds=bounds)
            except ValueError as e:
                if str(e) == "`x0` is infeasible.":
                    print(e)
                    # Fall back to a center-only fit; the surround
                    # amplitude is set to 0.
                    popt, _ = curve_fit(onedgauss,
                                        s,
                                        onoroff * fitv,
                                        p0=p_initial[:3])
                    popt = np.append(popt, [0, popt[1], popt[2]])
                else:
                    raise
            fit = centersurround_onedim(s, *popt)

            # Avoid dividing by zero when calculating center-surround index
            if popt[3] > 0:
                csi = popt[0] / popt[3]
            else:
                csi = 0
            cs_inds[i] = csi
            ax = plt.subplot(122)
            plf.spineless(ax)
            ax.set_yticks([])

            # We need to flip the vertical axis to match
            # with the STA next to it
            plt.plot(onoroff * fitv, -s, label='Data')
            plt.plot(onoroff * fit, -s, label='Fit')
            plt.axvline(0, linestyle='dashed', alpha=.5)
            plt.title(f'Center: a: {popt[0]:4.2f}, μ: {popt[1]:4.2f},' +
                      f' σ: {popt[2]:4.2f}\n' +
                      f'Surround: a: {popt[3]:4.2f}, μ: {popt[4]:4.2f},' +
                      f' σ: {popt[5]:4.2f}' + f'\n CS index: {csi:4.2f}')
            plt.subplots_adjust(top=.82)
            plt.suptitle(f'{exp_name}\n{stimname}\n{clusterids[i]}')
            plt.savefig(
                os.path.join(savepath, 'stripesurrounds',
                             clusterids[i] + '.svg'))
            plt.close()

        data.update({'cs_inds': cs_inds, 'polarities': polarities})
        np.savez(os.path.join(savepath, f'{stimnr}_data.npz'), **data)
"""

"""

import numpy as np
import matplotlib.pyplot as plt

import pycorrelate

import iofuncs as iof
import analysis_scripts as asc

exp = '20180710'
stim = 8

data = iof.load(exp, stim)
clusters = data['clusters']

# Spike data of all clusters. The layout is not visible here; presumably one
# row per cluster, given the commented-out xcorr of allspikes[0, :] below.
allspikes = data['all_spikes']

#plt.xcorr(allspikes[0, :])


def corr(x1, x2=None, window=200):
    """
    Cross-correlate two equal-length 1-D signals around zero lag.

    Parameters
    ----------
    x1 : numpy.ndarray
        First signal.
    x2 : numpy.ndarray, optional
        Second signal; defaults to ``x1`` (autocorrelation).
    window : int
        Number of lags to keep on each side of zero lag.

    Returns
    -------
    numpy.ndarray
        ``np.correlate(x1, x2, 'full')`` restricted to lags ``-window`` to
        ``+window`` (length ``2 * window + 1``). If the signals are shorter
        than the window, all available lags are returned.

    Raises
    ------
    ValueError
        If ``x1`` and ``x2`` have different shapes.
    """
    if x2 is None:
        x2 = x1
    # Explicit check instead of `assert`, which is stripped under -O.
    if x1.shape != x2.shape:
        raise ValueError('x1 and x2 must have the same shape, got '
                         f'{x1.shape} and {x2.shape}')
    n = x1.shape[0]
    # In 'full' mode the 2n - 1 outputs have zero lag at index n - 1. Clamp
    # the slice so a window wider than the signal returns every lag instead
    # of wrapping around through a negative start index.
    zero_lag = n - 1
    start = max(zero_lag - window, 0)
    stop = min(zero_lag + window + 1, 2 * n - 1)
    # np.correlate computes the full correlation directly, which is slow
    # for long signals.
    return np.correlate(x1, x2, 'full')[start:stop]
Example #14
0
 def read_datafile(self):
     """Load the data file of this object's experiment and stimulus."""
     experiment, stimulus = self.exp, self.stimnr
     return iof.load(experiment, stimulus)
Example #15
0
def plotstripestas(exp_name, stim_nrs):
    """
    Plot and save all the STAs from multiple stripe flicker stimuli.

    Each STA is drawn as a time x space image, with space centred on zero,
    and saved as ``<exp_dir>/data_analysis/<stimname>/STAs/<clusterid>.svg``.

    Parameters
    ----------
    exp_name : str
        Experiment name or path, resolved with ``iof.exp_dir_fixer``.
    stim_nrs : int or sequence of int
        Stripe flicker stimulus number(s). An empty sequence does nothing.
    """
    exp_dir = iof.exp_dir_fixer(exp_name)

    _, metadata = asc.read_spikesheet(exp_dir)
    px_size = metadata['pixel_size(um)']

    if isinstance(stim_nrs, int):
        stim_nrs = [stim_nrs]
    elif len(stim_nrs) == 0:
        return

    for stim_nr in stim_nrs:
        data = iof.load(exp_name, stim_nr)

        clusters = data['clusters']
        stas = data['stas']
        filter_length = data['filter_length']
        stx_w = data['stx_w']
        exp_name = data['exp_name']
        stimname = data['stimname']
        frame_duration = data['frame_duration']
        quals = data['quals']

        clusterids = plf.clusters_to_ids(clusters)

        t = np.arange(filter_length) * frame_duration * 1000
        # Half of the total spatial extent of the STA in mm:
        # stripes * stripe width [px] * pixel size [um] / 1000, halved
        # because the image spans [-vscale, vscale]. The full extent used to
        # serve as the half-extent, which doubled the distance axis. It was
        # also truncated with int(), which collapsed STAs narrower than 1 mm
        # to zero height.
        vscale = stas[0].shape[0] * stx_w * px_size / 1000 / 2

        savepath = os.path.join(exp_dir, 'data_analysis', stimname, 'STAs')
        os.makedirs(savepath, exist_ok=True)

        for i in range(clusters.shape[0]):
            sta = stas[i]

            # Symmetric colour limits so zero sits in the middle of RdBu.
            vmax = np.max(np.abs(sta))
            vmin = -vmax
            plt.figure(figsize=(6, 15))
            ax = plt.subplot(111)
            im = ax.imshow(sta,
                           cmap='RdBu',
                           vmin=vmin,
                           vmax=vmax,
                           extent=[0, t[-1], -vscale, vscale],
                           aspect='auto')
            plt.xlabel('Time [ms]')
            plt.ylabel('Distance [mm]')

            plf.spineless(ax)
            plf.colorbar(im, ticks=[vmin, 0, vmax], format='%.2f', size='2%')
            plt.suptitle(f'{exp_name}\n{stimname}\n'
                         f'{clusterids[i]} Rating: {clusters[i][2]}\n'
                         f'STA quality: {quals[i]:4.2f}')
            plt.subplots_adjust(top=.90)
            plt.savefig(os.path.join(savepath, clusterids[i] + '.svg'),
                        bbox_inches='tight')
            plt.close()
        print(f'Plotting of {stimname} completed.')
Example #16
0

def cutstripe(sta, max_i, fsize):
    """
    Cut a window of ``2 * fsize + 1`` rows centred on the STA maximum.

    Parameters
    ----------
    sta : numpy.ndarray
        Stripe STA; rows (axis 0) are cut and all columns are kept.
    max_i : sequence of int
        Index of the maximum. ``max_i[0]`` is its row and ``max_i[-1]`` its
        column.
    fsize : int
        Number of rows to keep on each side of ``max_i[0]``.

    Returns
    -------
    sta_r : numpy.ndarray
        The cut STA, with exactly ``2 * fsize + 1`` rows.
    max_i_r : numpy.ndarray
        Index of the maximum in the cut STA, ``[fsize, max_i[-1]]``.

    Raises
    ------
    ValueError
        If the window would extend past either edge of ``sta``.
    """
    start = max_i[0] - fsize
    stop = max_i[0] + fsize + 1
    # A window starting exactly at row 0 is valid. One whose end lies past
    # the last row is not: slicing would silently return fewer rows.
    if start < 0 or stop > sta.shape[0]:
        raise ValueError('Cutting outside the STA range.')
    sta_r = sta[start:stop, :]
    max_i_r = np.append(fsize, max_i[-1])
    return sta_r, max_i_r


exp_name = '20171122'
# Presumably the stripe flicker stimulus numbers; only the first is loaded.
stripes = [8, 9, 10, 11]

exp_dir = iof.exp_dir_fixer(exp_name)

data = iof.load(exp_name, stripes[0])


_, metadata = asc.read_spikesheet(exp_dir)
px_size = metadata['pixel_size(um)']
exp_name = data['exp_name']
stx_w = data['stx_w']


clusters = data['clusters']
clusterids = plf.clusters_to_ids(clusters)

# Number of stripes (stx_w pixels of px_size um each) spanning 700 um,
# rounded down, and the distance those stripes actually cover.
fsize = int(700/(stx_w*px_size))
vscale = fsize * stx_w*px_size

Created on Wed Jan 31 10:43:08 2018

@author: ycan

Compare the polarities (onoroff, based on absmax of sta,
used to determine whether to flip for fitting and plotting) from
stripesurround with onoffindex from onoffsteps stimulus.

Most cells were expected to agree, but this is not actually the case.
No further action has been taken on this so far.
"""
import iofuncs as iof
import matplotlib.pyplot as plt

exp_name = '20180124'
# (onoffsteps stimulus, stripe flicker stimulus) pairs, one per light level.
pairs = [(3, 6), (8, 12)]
conditions = ['Low', 'High']

for i, pair in enumerate(pairs):
    onoffbias = iof.load(exp_name, pair[0])['onoffbias']
    data = iof.load(exp_name, pair[1])
    quals = data['quals']
    cs_inds = data['cs_inds']
    polarities = data['polarities']
    # One scatter series per light level: ON-OFF bias against the
    # stripesurround polarity (+1 / -1) of each cell.
    plt.scatter(onoffbias, polarities, label=conditions[i])
    plt.xlabel('On-Off bias')
    plt.ylabel('Polarity index')

    #    plt.ylabel('Center-Surround index')
    plt.legend()
Example #18
0
def csindexchange(exp_name, onoffcutoff=.5, qualcutoff=9):
    """
    Plots the change in center surround indexes in different light
    levels. Also classifies based on ON-OFF index from the onoffsteps
    stimulus at the matching light level.

    Cells are coloured blue if their ON-OFF bias is above ``onoffcutoff``
    in all three conditions (ON), red if below ``-onoffcutoff`` in all three
    (OFF), black if in between in all three (ON-OFF), and white otherwise.
    The figure is saved as ``<exp_dir>/data_analysis/csinds.svg`` and
    ``.pdf``, then shown.

    Parameters
    ----------
    exp_name : str
        One of the supported experiments (20180118, 20180124, 20180207),
        matched by substring.
    onoffcutoff : float
        Threshold on the ON-OFF bias used for the colour classification.
    qualcutoff : float
        A cell is kept only if its STA quality exceeds this value in all
        three stripe flicker stimuli; otherwise its values are set to NaN.

    Raises
    ------
    ValueError
        If ``exp_name`` is not a supported experiment (previously an
        UnboundLocalError).
    """

    # For now there are only three experiments with the
    # different light levels and the indices of stimuli
    # are different. To automate it will be tricky and
    # ROI is just not enough to justify; so they are
    # hard coded.
    if '20180124' in exp_name or '20180207' in exp_name:
        stripeflicker = [6, 12, 17]
        onoffs = [3, 8, 14]
    elif '20180118' in exp_name:
        stripeflicker = [7, 14, 19]
        onoffs = [3, 10, 16]
    else:
        raise ValueError(f'No stimulus numbers known for {exp_name!r}; '
                         'expected 20180118, 20180124 or 20180207.')

    exp_dir = iof.exp_dir_fixer(exp_name)
    exp_name = os.path.split(exp_dir)[-1]
    clusternr = asc.read_spikesheet(exp_name)[0].shape[0]

    # Collect all CS indices, on-off indices and quality scores
    # (one row per light level, one column per cluster).
    csinds = np.zeros((3, clusternr))
    quals = np.zeros((3, clusternr))

    onoffinds = np.zeros((3, clusternr))
    for i, stim in enumerate(onoffs):
        onoffinds[i, :] = iof.load(exp_name, stim)['onoffbias']

    for i, stim in enumerate(stripeflicker):
        data = iof.load(exp_name, stim)
        quals[i, :] = data['quals']
        csinds[i, :] = data['cs_inds']

    csinds_f = np.copy(csinds)
    quals_f = np.copy(quals)
    onoffbias_f = np.copy(onoffinds)

    # Set every cell whose quality does not exceed the cutoff in all light
    # levels to NaN, so it drops out of the plots.
    excluded = ~np.all(quals > qualcutoff, axis=0)
    quals_f[:, excluded] = np.nan
    csinds_f[:, excluded] = np.nan
    onoffbias_f[:, excluded] = np.nan

    # Define the color for each point depending on each cell's ON-OFF index
    # by appending the color name in an array. Excluded (NaN) cells fail
    # every comparison and end up white.
    colors = []
    for j in range(onoffbias_f.shape[1]):
        if np.all(onoffbias_f[:, j] > onoffcutoff):
            # If it stays ON throughout
            colors.append('blue')
        elif np.all(onoffbias_f[:, j] < -onoffcutoff):
            # If it stays OFF throughout
            colors.append('red')
        elif (np.all(onoffcutoff > onoffbias_f[:, j])
              and np.all(onoffbias_f[:, j] > -onoffcutoff)):
            # If it's ON-OFF throughout
            colors.append('black')
        else:
            colors.append('white')

    scatterkwargs = {'c': colors, 'alpha': .6, 'linewidths': 0}

    colorcategories = ['blue', 'red', 'black']
    colorlabels = ['ON', 'OFF', 'ON-OFF']

    # Create an array for all the colors to use with plt.legend()
    patches = []
    for color, label in zip(colorcategories, colorlabels):
        patches.append(mpatches.Patch(color=color, label=label))

    # End points of the identity line drawn in both panels.
    x = [np.nanmin(csinds_f), np.nanmax(csinds_f)]

    plt.figure(figsize=(12, 6))
    ax1 = plt.subplot(121)
    plt.legend(handles=patches, fontsize='small')
    plt.scatter(csinds_f[0, :], csinds_f[1, :], **scatterkwargs)
    plt.plot(x, x, 'r--', alpha=.5)
    plt.xlabel('Low 1')
    plt.ylabel('High')

    ax1.set_aspect('equal')
    plf.spineless(ax1)

    ax2 = plt.subplot(122)
    plt.scatter(csinds_f[0, :], csinds_f[2, :], **scatterkwargs)
    plt.plot(x, x, 'r--', alpha=.5)
    plt.xlabel('Low 1')
    plt.ylabel('Low 2')
    ax2.set_aspect('equal')
    plf.spineless(ax2)

    plt.suptitle(f'Center-Surround Index Change\n{exp_name}')
    plt.text(.8,
             -0.1,
             f'qualcutoff:{qualcutoff} onoffcutoff:{onoffcutoff}',
             fontsize='small',
             transform=ax2.transAxes)
    plotsave = os.path.join(exp_dir, 'data_analysis', 'csinds')
    plt.savefig(plotsave + '.svg', format='svg', bbox_inches='tight')
    plt.savefig(plotsave + '.pdf', format='pdf', bbox_inches='tight')
    plt.show()
    plt.close()
Example #19
0
def plot_checker_stas(exp_name, stim_nr, filename=None):
    """
    Plot and save all STAs from checkerflicker analysis. The plots
    will be saved in a new folder called STAs under the data analysis
    path of the stimulus.

    Each STA is drawn as a grid with one image per time frame and saved as
    ``<exp_dir>/data_analysis/<stimname>/STAs/<clusterid>.png``.

    By default the stimulus' standard data file is loaded with ``iof.load``.
    If a different file is to be used, ``filename`` should be supplied. It
    is passed to ``iof.load`` as ``fname``. Its name without a trailing
    '.npz' becomes a label that is appended to the output folder name
    (``STAs_<label>``) and to the stimulus name in the figure title.
    """

    from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar

    exp_dir = iof.exp_dir_fixer(exp_name)
    stim_nr = str(stim_nr)

    if filename:
        filename = str(filename)
        # Remove a trailing '.npz' explicitly. str.strip('.npz') strips any of
        # the characters '.', 'n', 'p', 'z' from both ends, which mangles
        # names such as 'zap.npz' (-> 'a').
        label = (filename[:-len('.npz')] if filename.endswith('.npz')
                 else filename)
        savefolder = 'STAs_' + label
    else:
        savefolder = 'STAs'
        label = ''

    _, metadata = asc.read_spikesheet(exp_dir)
    px_size = metadata['pixel_size(um)']

    data = iof.load(exp_name, stim_nr, fname=filename)

    clusters = data['clusters']
    stas = data['stas']
    filter_length = data['filter_length']
    stx_h = data['stx_h']
    exp_name = data['exp_name']
    stimname = data['stimname']

    for j in range(clusters.shape[0]):
        a = stas[j]
        subplot_arr = plf.numsubplots(filter_length)
        # Symmetric colour limits around zero.
        sta_max = np.max(np.abs([np.max(a), np.min(a)]))
        sta_min = -sta_max
        plt.figure(dpi=250)
        for i in range(filter_length):
            ax = plt.subplot(subplot_arr[0], subplot_arr[1], i + 1)
            im = ax.imshow(a[:, :, i],
                           vmin=sta_min,
                           vmax=sta_max,
                           cmap=iof.config('colormap'))
            ax.set_aspect('equal')
            plt.axis('off')
            if i == 0:
                # Scale bar of 10 checkers, labelled with their width in um.
                scalebar = AnchoredSizeBar(ax.transData,
                                           10,
                                           '{} µm'.format(10 * stx_h *
                                                          px_size),
                                           'lower left',
                                           pad=0,
                                           color='k',
                                           frameon=False,
                                           size_vertical=1)
                ax.add_artist(scalebar)
            if i == filter_length - 1:
                plf.colorbar(im, ticks=[sta_min, 0, sta_max], format='%.2f')
        plt.suptitle('{}\n{}\n'
                     '{:0>3}{:0>2} Rating: {}'.format(exp_name,
                                                      stimname + label,
                                                      clusters[j][0],
                                                      clusters[j][1],
                                                      clusters[j][2]))

        savepath = os.path.join(
            exp_dir, 'data_analysis', stimname, savefolder,
            '{:0>3}{:0>2}'.format(clusters[j][0], clusters[j][1]))

        os.makedirs(os.path.split(savepath)[0], exist_ok=True)

        plt.savefig(savepath + '.png', bbox_inches='tight')
        plt.close()
    print(f'Plotted checkerflicker STA for {stimname}')
Example #20
0
@author: ycan

Compare the change in ON-OFF bias across different light conditions.
"""
import iofuncs as iof
import os
import matplotlib.pyplot as plt
import plotfuncs as plf
import numpy as np

exp_name = '20180124'
exp_dir = iof.exp_dir_fixer(exp_name)

# ON-OFF bias of each cell (columns) for the onoffsteps stimuli of the three
# light conditions (rows). The loaded arrays are stacked instead of filling a
# preallocated (3, 30) array. That hard-coded cell count failed for any
# experiment that does not have exactly 30 cells.
onoff_stims = [3, 8, 14]
onoffinds = np.vstack([iof.load(exp_name, stim)['onoffbias']
                       for stim in onoff_stims])

#%%
labels = ['1_low', '2_high', '3_low']
plt.figure(figsize=(12, 10))
ax = plt.subplot(111)
# With a 2-D y, plt.plot draws one line per column, i.e. one per cell.
plt.plot(labels, onoffinds)
plt.ylabel('On-Off Bias')
plt.title('On-Off Bias Change')
plf.spineless(ax)

plotsave = os.path.join(exp_dir, 'data_analysis', 'onoffbias')

#plt.savefig(plotsave+'.svg', format = 'svg', bbox_inches='tight')
#plt.savefig(plotsave+'.pdf', format = 'pdf', bbox_inches='tight')
plt.show()
Example #21
0
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Feb  1 11:29:37 2018

@author: ycan
"""

import plotfuncs as plf
import matplotlib.pyplot as plt
import miscfuncs as msc
import iofuncs as iof

data = iof.load('20180124', 12)
index = 5
sta = data['stas'][index]
max_i = data['max_inds'][index]
sta, max_i = msc.cutstripe(sta, max_i, 30)

# Colormaps to compare side by side on the same STA. A catalogue of every
# matplotlib colormap name used to be parsed into `c` here, but the result
# was immediately overwritten by this list, so that dead code was removed.
c = ['bwr_r', 'RdBu', 'seismic_r', 'bwr', 'RdBu_r', 'seismic']
dims = plf.numsubplots(len(c))
plt.figure(figsize=(20, 20))
for i, cm in enumerate(c):
    ax = plt.subplot(dims[0], dims[1], i + 1)
    im = plf.stashow(sta, ax, cmap=cm, ticks=[])
    plt.axis('off')
    im.axes.get_xaxis().set_visible(False)
Example #22
0
def plotcheckersurround(exp_name, stim_nr, filename=None, spikecutoff=1000,
                        ratingcutoff=4, staqualcutoff=0, inner_b=2,
                        outer_b=4):

    """
    Divides into center and surround by fitting 2D Gaussian, and plot
    temporal components.

    The Gaussian is fitted to the STA frame at the time of the maximum. The
    center is the region closer than ``inner_b`` (in Mahalanobis distance
    to the fit), and the surround lies between ``inner_b`` and ``outer_b``.
    For each selected cell, the fitted frame with both contours and the mean
    temporal course of center and surround are plotted and saved as
    ``<exp_dir>/data_analysis/<stimname>/surroundplots[_<label>]/<id>.svg``.
    Cells whose cut-out region cannot be obtained are skipped.

    filename:
        Data file to load instead of the default one (passed to
        ``iof.load`` as ``fname``). Its name without a trailing '.npz' is
        used as a label for the output folder.

    spikecutoff:
        Minimum number of spikes to include.

    ratingcutoff:
        Minimum spike sorting rating to include.

    staqualcutoff:
        Minimum STA quality (as measured by z-score) to include.

    inner_b:
        Defined limit between receptive field center and surround
        in units of sigma.

    outer_b:
        Defined limit of the end of receptive field surround.
    """
    # Local import, as in plot_checker_stas; it is needed for the scale bar
    # and is harmless if also imported at module level.
    from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar

    exp_dir = iof.exp_dir_fixer(exp_name)
    stim_nr = str(stim_nr)

    if filename:
        filename = str(filename)
        # Remove a trailing '.npz' explicitly. str.strip('.npz') strips any of
        # the characters '.', 'n', 'p', 'z' from both ends and mangles names
        # such as 'zap.npz' (-> 'a').
        label = (filename[:-len('.npz')] if filename.endswith('.npz')
                 else filename)
        savefolder = 'surroundplots_' + label
    else:
        savefolder = 'surroundplots'
        label = ''

    _, metadata = asc.read_spikesheet(exp_name)
    px_size = metadata['pixel_size(um)']

    data = iof.load(exp_name, stim_nr, fname=filename)

    clusters = data['clusters']
    stas = data['stas']
    stx_h = data['stx_h']
    exp_name = data['exp_name']
    stimname = data['stimname']
    max_inds = data['max_inds']
    frame_duration = data['frame_duration']
    filter_length = data['filter_length']
    quals = data['quals'][-1, :]

    spikenrs = data['spikenrs']

    # Keep only cells that pass all three inclusion criteria. The boolean
    # mask selects the same indices, in ascending order, as the previous
    # per-index membership tests, without their quadratic cost.
    keep = ((spikenrs > spikecutoff)
            & (clusters[:, 2] <= ratingcutoff)
            & (quals > staqualcutoff))
    choose = np.where(keep)[0]
    clusters = clusters[choose]
    stas = list(np.array(stas)[choose])
    max_inds = list(np.array(max_inds)[choose])

    clusterids = plf.clusters_to_ids(clusters)

    t = np.arange(filter_length)*frame_duration*1000

    # Determine frame size so that the total frame covers
    # an area large enough i.e. 2*700um
    f_size = int(700/(stx_h*px_size))

    del data

    for i in range(clusters.shape[0]):

        sta_original = stas[i]
        max_i_original = max_inds[i]

        try:
            # cut_around_center lives in miscfuncs (imported as msc, as used
            # by fitgaussian); the `mf` alias used before is not defined.
            sta, max_i = msc.cut_around_center(sta_original,
                                               max_i_original, f_size)
        except ValueError:
            # Presumably raised when the cut falls outside the STA; skip.
            continue

        fit_frame = sta[:, :, max_i[2]]

        if np.max(fit_frame) != np.max(np.abs(fit_frame)):
            onoroff = -1
        else:
            onoroff = 1

        Y, X = np.meshgrid(np.arange(fit_frame.shape[1]),
                           np.arange(fit_frame.shape[0]))

        with warnings.catch_warnings():
            warnings.filterwarnings('ignore',
                                    '.*divide by zero*.', RuntimeWarning)
            # Fit with positive polarity so OFF cells are fitted as peaks.
            pars = gfit.gaussfit(fit_frame*onoroff)
            f = gfit.twodgaussian(pars)
            Z = f(X, Y)

        # Convert the fitted surface into Mahalanobis distance:
        # Z = pars[0] + pars[1] * exp(-d**2 / 2)  =>  d = sqrt(-2 ln(...)).
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore',
                                    '.*divide by zero*.', RuntimeWarning)
            Zm = np.log((Z-pars[0])/pars[1])
        Zm[np.isinf(Zm)] = np.nan
        Zm = np.sqrt(Zm*-2)

        ax = plt.subplot(1, 2, 1)

        plf.stashow(fit_frame, ax)
        ax.set_aspect('equal')

        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', category=UserWarning)
            warnings.filterwarnings('ignore', '.*invalid value encountered*.')
            ax.contour(Y, X, Zm, [inner_b, outer_b],
                       cmap=plf.RFcolormap(('C0', 'C1')))

        barsize = 100/(stx_h*px_size)
        scalebar = AnchoredSizeBar(ax.transData,
                                   barsize, '100 µm',
                                   'lower left',
                                   pad=1,
                                   color='k',
                                   frameon=False,
                                   size_vertical=.2)
        ax.add_artist(scalebar)

        # Masks are True where pixels are EXCLUDED (numpy masked-array
        # convention), broadcast over the time axis of the STA.
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore',
                                    '.*invalid value encountered in*.',
                                    RuntimeWarning)
            center_mask = np.logical_not(Zm < inner_b)
            center_mask_3d = np.broadcast_arrays(sta,
                                                 center_mask[..., None])[1]
            surround_mask = np.logical_not(np.logical_and(Zm > inner_b,
                                                          Zm < outer_b))
            surround_mask_3d = np.broadcast_arrays(sta,
                                                   surround_mask[..., None])[1]

        sta_center = np.ma.array(sta, mask=center_mask_3d)
        sta_surround = np.ma.array(sta, mask=surround_mask_3d)

        sta_center_temporal = np.mean(sta_center, axis=(0, 1))
        sta_surround_temporal = np.mean(sta_surround, axis=(0, 1))

        ax1 = plt.subplot(1, 2, 2)
        l1 = ax1.plot(t, sta_center_temporal,
                      label='Center\n(<{}σ)'.format(inner_b),
                      color='C0')
        sct_max = np.max(np.abs(sta_center_temporal))
        ax1.set_ylim(-sct_max, sct_max)
        ax2 = ax1.twinx()
        l2 = ax2.plot(t, sta_surround_temporal,
                      label='Surround\n({}σ<x<{}σ)'.format(inner_b, outer_b),
                      color='C1')
        sst_max = np.max(np.abs(sta_surround_temporal))
        ax2.set_ylim(-sst_max, sst_max)
        plf.spineless(ax1)
        plf.spineless(ax2)
        ax1.tick_params('y', colors='C0')
        ax2.tick_params('y', colors='C1')
        plt.xlabel('Time[ms]')
        plt.axhline(0, linestyle='dashed', linewidth=1)

        lines = l1+l2
        labels = [line.get_label() for line in lines]
        plt.legend(lines, labels, fontsize=7)
        plt.title('Temporal components')
        plt.suptitle(f'{exp_name}\n{stimname}\n{clusterids[i]}')

        plt.subplots_adjust(wspace=.5, top=.85)

        plotpath = os.path.join(exp_dir, 'data_analysis',
                                stimname, savefolder)
        os.makedirs(plotpath, exist_ok=True)

        plt.savefig(os.path.join(plotpath, clusterids[i])+'.svg',
                    format='svg', dpi=300)
        plt.close()
    print(f'Plotted checkerflicker surround for {stimname}')
def csindexchange(exp_name, onoffcutoff=.5, qualcutoff=qualcutoff):
    """
    Return center-surround indices and ON-OFF classification of cells
    recorded in two light levels (mesopic and photopic).

    For now there are only three experiments with the different light
    levels, and the stimulus indices differ between them. Automating the
    lookup would be tricky and is not worth it, so the indices are hard
    coded.

    Parameters
    ----------
    exp_name : str
        Experiment name. Must contain '20180124', '20180207' or '20180118'.
    onoffcutoff : float
        ON-OFF bias threshold. A cell with bias above ``onoffcutoff`` in
        both light levels counts as ON, one below ``-onoffcutoff`` as OFF,
        and one strictly between the two as ON-OFF.
    qualcutoff : float
        A cell is kept only if its stripeflicker quality exceeds this value
        in both light levels. Otherwise all of its values are set to NaN.

    Returns
    -------
    csinds_f : np.ndarray, shape (2, nclusters)
        Center-surround indices, one row per stripeflicker stimulus.
    colors : list
        One entry per cell:
        ``colorcategories[0]`` ON throughout,
        ``colorcategories[1]`` OFF throughout,
        ``colorcategories[2]`` ON-OFF throughout,
        ``colorcategories[3]`` increasing ON-OFF bias,
        ``colorcategories[4]`` decreasing ON-OFF bias,
        ``'yellow'`` otherwise (this includes cells removed by the quality
        filter).
    onoffbias_f : np.ndarray, shape (2, nclusters)
        ON-OFF bias indices, one row per onoff stimulus.
    quals_f : np.ndarray, shape (2, nclusters)
        Quality scores, one row per stripeflicker stimulus.

    Raises
    ------
    ValueError
        If the stimulus indices for ``exp_name`` are not known.
    """
    # NOTE(review): row 0 is presumably the mesopic and row 1 the photopic
    # light level (order of the hard-coded stimulus numbers) — confirm.
    if '20180124' in exp_name or '20180207' in exp_name:
        stripeflicker = [6, 17]
        onoffs = [3, 14]
    elif '20180118' in exp_name:
        stripeflicker = [7, 19]
        onoffs = [3, 16]
    else:
        # Without this, the stimulus lists would be unbound further down.
        raise ValueError(f'Stimulus indices for {exp_name} are not known; '
                         'only 20180118, 20180124 and 20180207 are '
                         'supported.')

    exp_dir = iof.exp_dir_fixer(exp_name)
    exp_name = os.path.split(exp_dir)[-1]
    clusternr = asc.read_spikesheet(exp_name)[0].shape[0]

    # Collect on-off indices, CS indices and quality scores
    onoffinds = np.zeros((2, clusternr))
    for i, stim in enumerate(onoffs):
        onoffinds[i, :] = iof.load(exp_name, stim)['onoffbias']

    csinds = np.zeros((2, clusternr))
    quals = np.zeros((2, clusternr))
    for i, stim in enumerate(stripeflicker):
        data = iof.load(exp_name, stim)
        quals[i, :] = data['quals']
        csinds[i, :] = data['cs_inds']

    # Keep a cell only if it passes the quality cutoff in every light
    # level. Excluded cells are set to NaN in all outputs; the 1D mask
    # broadcasts over the two rows.
    keep = np.all(quals > qualcutoff, axis=0)
    csinds_f = np.where(keep, csinds, np.nan)
    quals_f = np.where(keep, quals, np.nan)
    onoffbias_f = np.where(keep, onoffinds, np.nan)

    # Change of polarity for each cell: second row minus first row
    biaschange = np.diff(onoffbias_f, axis=0)[0]

    # Assign a color to each cell depending on its ON-OFF indices
    colors = []
    for bias, change in zip(onoffbias_f.T, biaschange):
        if np.all(bias > onoffcutoff):
            # Stays ON throughout
            colors.append(colorcategories[0])
        elif np.all(bias < -onoffcutoff):
            # Stays OFF throughout
            colors.append(colorcategories[1])
        elif np.all(np.abs(bias) < onoffcutoff):
            # ON-OFF throughout
            colors.append(colorcategories[2])
        elif change > 0:
            # Not consistent in any category, polarity increases
            colors.append(colorcategories[3])
        elif change < 0:
            # Not consistent in any category, polarity decreases
            colors.append(colorcategories[4])
        else:
            # Comparisons with NaN are False, so cells excluded by the
            # quality filter end up here.
            colors.append('yellow')

    return csinds_f, colors, onoffbias_f, quals_f
Example #24
0
def stripesurround_SVD(exp_name, stimnrs, nrcomponents=5):
    """
    Fit a one-dimensional center-surround model to noise-reduced
    stripeflicker STAs, plot the fits and save the results.

    For every stimulus in ``stimnrs`` and every cluster:

    1. The STA is replaced by the sum of its first ``nrcomponents`` SVD
       components (``sumcomponent``).
    2. It is cut around its maximum with ``msc.cutstripe``.
    3. It is averaged over ~100 ms around the peak time point.
    4. It is fitted with ``centersurround_onedim``. If that fit fails to
       converge or starts infeasible, a single gaussian (``onedgauss``) is
       fitted instead, with a zero-amplitude surround.

    One SVG per cluster is written to
    ``<exp_dir>/data_analysis/<stimname>/stripesurrounds_SVD/``. The loaded
    data, extended with 'cs_inds', 'polarities' and 'included', is saved to
    ``<exp_dir>/data_analysis/<stimname>/<stimnr>_data_SVD.npz``.

    Parameters
    ----------
    exp_name : str
        Experiment name, resolved with ``iof.exp_dir_fixer``.
    stimnrs : int or list of int
        Stimulus number(s) to analyze.
    nrcomponents : int
        First N components of singular value decomposition (SVD)
        will be used to reduce noise.

    Notes
    -----
    Clusters that cannot be analyzed (cut outside the STA range, or
    infs/NaNs in the data) are marked False in 'included' and keep NaN as
    their 'cs_inds' entry. Their 'polarities' entry is also NaN if the
    cluster was skipped before its polarity was determined.
    """
    exp_dir = iof.exp_dir_fixer(exp_name)

    if isinstance(stimnrs, int):
        stimnrs = [stimnrs]

    # The pixel size is a property of the experiment, so read it only once.
    _, metadata = asc.read_spikesheet(exp_dir)
    px_size = metadata['pixel_size(um)']

    for stimnr in stimnrs:
        data = iof.load(exp_name, stimnr)

        clusters = data['clusters']
        stas = data['stas']
        max_inds = data['max_inds']
        filter_length = data['filter_length']
        stx_w = data['stx_w']
        exp_name = data['exp_name']
        stimname = data['stimname']
        frame_duration = data['frame_duration']
        quals = data['quals']

        # Record which clusters are ignored during analysis
        try:
            included = data['included']
        except KeyError:
            included = [True] * clusters.shape[0]

        # Average STA values 100 ms around the brightest frame to
        # minimize noise (frame_duration is in seconds, so this is the
        # number of frames on each side of the peak).
        cut_time = int(100 / (frame_duration * 1000) / 2)

        # Tolerance for distance between center and surround
        # distributions 60 μm
        dtol = int((60 / px_size) / 2)

        clusterids = plf.clusters_to_ids(clusters)

        # Number of stripes covering 700 µm on each side of the center
        fsize = int(700 / (stx_w * px_size))
        t = np.arange(filter_length) * frame_duration * 1000
        vscale = fsize * stx_w * px_size

        # NaN rather than np.empty so that clusters skipped below do not
        # end up with uninitialized values in the saved file.
        cs_inds = np.full(clusters.shape[0], np.nan)
        polarities = np.full(clusters.shape[0], np.nan)

        savepath = os.path.join(exp_dir, 'data_analysis', stimname)
        plotdir = os.path.join(savepath, 'stripesurrounds_SVD')
        os.makedirs(plotdir, exist_ok=True)

        for i in range(clusters.shape[0]):
            sta = stas[i]
            max_i = max_inds[i]

            # From this point on, use the low-rank approximation
            # version
            sta_reduced = sumcomponent(nrcomponents, sta)

            try:
                sta_reduced, max_i = msc.cutstripe(sta_reduced, max_i,
                                                   fsize * 2)
            except ValueError as e:
                if str(e) == 'Cutting outside the STA range.':
                    included[i] = False
                    continue
                else:
                    print(f'Error while analyzing {stimname}\n' +
                          f'Index:{i}    Cluster:{clusterids[i]}')
                    raise

            # Isolate the time point from which the fit will
            # be obtained. Shift the peak if needed so the averaging
            # window does not start at a negative (wrapping) index.
            if max_i[1] < cut_time:
                max_i[1] = cut_time + 1
            fitv = np.mean(sta_reduced[:, max_i[1] - cut_time:max_i[1] +
                                       cut_time + 1],
                           axis=1)

            # Make a space vector
            s = np.arange(fitv.shape[0])

            # The cell is OFF (-1) if its largest deflection is negative
            if np.max(fitv) != np.max(np.abs(fitv)):
                onoroff = -1
            else:
                onoroff = 1
            polarities[i] = onoroff
            # Determine the peak values for center and surround
            # to give as initial parameters for curve fitting
            centerpeak = onoroff * np.max(fitv * onoroff)
            surroundpeak = onoroff * np.max(fitv * -onoroff)

            # Define initial guesses for the center and surround gaussians
            # First set of values are for center, second for surround.
            p_initial = [centerpeak, max_i[0], 2, surroundpeak, max_i[0], 8]
            if onoroff == 1:
                bounds = ([0, -np.inf, -np.inf, 0, max_i[0] - dtol, 4], [
                    np.inf, np.inf, np.inf, np.inf, max_i[0] + dtol, 20
                ])
            else:
                bounds = ([
                    -np.inf, -np.inf, -np.inf, -np.inf, max_i[0] - dtol, 4
                ], [0, np.inf, np.inf, 0, max_i[0] + dtol, 20])

            try:
                popt, _ = curve_fit(centersurround_onedim,
                                    s,
                                    fitv,
                                    p0=p_initial,
                                    bounds=bounds)
            except (ValueError, RuntimeError) as e:
                er = str(e)
                if (er == "`x0` is infeasible."
                        or er.startswith("Optimal parameters not found")):
                    # Fall back to a center-only fit; surround amplitude 0
                    popt, _ = curve_fit(onedgauss, s, fitv, p0=p_initial[:3])
                    popt = np.append(popt, [0, popt[1], popt[2]])
                elif er == "array must not contain infs or NaNs":
                    included[i] = False
                    continue
                else:
                    print(f'Error while analyzing {stimname}\n' +
                          f'Index:{i}    Cluster:{clusterids[i]}')
                    raise

            fit = centersurround_onedim(s, *popt)
            popt[0] = popt[0] * onoroff
            popt[3] = popt[3] * onoroff

            # Center-surround index: surround amplitude / center amplitude
            csi = popt[3] / popt[0]
            cs_inds[i] = csi

            plt.figure(figsize=(10, 9))
            ax = plt.subplot(121)
            plf.stashow(sta_reduced, ax, extent=[0, t[-1], -vscale, vscale])
            ax.set_xlabel('Time [ms]')
            ax.set_ylabel('Distance [µm]')
            ax.set_title(f'Using first {nrcomponents} components of SVD',
                         fontsize='small')

            ax = plt.subplot(122)
            plf.spineless(ax)
            ax.set_yticks([])
            # We need to flip the vertical axis to match
            # with the STA next to it
            plt.plot(onoroff * fitv, -s, label='Data')
            plt.plot(onoroff * fit, -s, label='Fit')
            plt.axvline(0, linestyle='dashed', alpha=.5)
            plt.title(f'Center: a: {popt[0]:4.2f}, μ: {popt[1]:4.2f},' +
                      f' σ: {popt[2]:4.2f}\n' +
                      f'Surround: a: {popt[3]:4.2f}, μ: {popt[4]:4.2f},' +
                      f' σ: {popt[5]:4.2f}' + f'\n CS index: {csi:4.2f}')
            plt.subplots_adjust(top=.85)
            plt.suptitle(f'{exp_name}\n{stimname}\n{clusterids[i]} ' +
                         f'Q: {quals[i]:4.2f}')
            plt.savefig(os.path.join(plotdir, clusterids[i] + '.svg'),
                        bbox_inches='tight')
            plt.close()

        data.update({
            'cs_inds': cs_inds,
            'polarities': polarities,
            'included': included
        })
        np.savez(os.path.join(savepath, f'{stimnr}_data_SVD.npz'), **data)
        print(f'Surround plotted and saved for {stimname}.')
Example #25
0
def plotcheckersvd(expname, stimnr, filename=None):
    """
    Plot the first two components of SVD analysis of checkerflicker STAs.

    For each cluster, the STA is cut around its maximum with
    ``msc.cut_around_center``.

    - Top row: the frame at the peak time point, plus the first two
      spatial SVD components.
    - Bottom row: the center pixel time course, plus the first two
      temporal SVD components.

    Clusters whose STA cannot be cut or decomposed are skipped.

    Parameters
    ----------
    expname : str
        Experiment name, resolved with ``iof.exp_dir_fixer``.
    stimnr : int
        Stimulus number.
    filename : str, optional
        Alternative analysis file, passed on to ``iof.load``. When given,
        plots go to ``SVD_<label>``, where label is the file name without a
        trailing '.npz'. Otherwise they go to ``SVD``.
    """
    exp_dir = iof.exp_dir_fixer(expname)
    _, metadata = asc.read_spikesheet(exp_dir)
    px_size = metadata['pixel_size(um)']

    if filename:
        filename = str(filename)
        # Remove only a trailing '.npz'. str.strip('.npz') would remove any
        # of the characters '.', 'n', 'p', 'z' from both ends of the name
        # (e.g. 'pz_data.npz' -> '_data').
        suffix = '.npz'
        if filename.endswith(suffix):
            label = filename[:-len(suffix)]
        else:
            label = filename
        savefolder = 'SVD_' + label
    else:
        savefolder = 'SVD'

    data = iof.load(expname, stimnr, filename)

    stas = data['stas']
    max_inds = data['max_inds']
    clusters = data['clusters']
    stx_h = data['stx_h']
    frame_duration = data['frame_duration']
    stimname = data['stimname']
    exp_name = data['exp_name']

    clusterids = plf.clusters_to_ids(clusters)

    # Determine frame size so that the total frame covers
    # an area large enough i.e. 2*700um
    f_size = int(700 / (stx_h * px_size))

    titles = ['center px', 'SVD spatial 1', 'SVD spatial 2']
    rows = 2
    cols = 3
    plotpath = os.path.join(exp_dir, 'data_analysis', stimname, savefolder)

    for i in range(clusters.shape[0]):
        sta = stas[i]
        max_i = max_inds[i]

        try:
            sta, max_i = msc.cut_around_center(sta, max_i, f_size=f_size)
        except ValueError:
            continue
        # Spatial frame at the time index of the STA maximum
        fit_frame = sta[:, :, max_i[2]]

        try:
            sp1, sp2, t1, t2, _, _ = msc.svd(sta)
        # If the STA is noisy (msc.cut_around_center produces an empty array)
        # SVD cannot be calculated, in this case we skip that cluster.
        except np.linalg.LinAlgError:
            continue

        plotthese = [fit_frame, sp1, sp2]

        plt.figure(dpi=200)
        plt.suptitle(f'{exp_name}\n{stimname}\n{clusterids[i]}')

        # Symmetric color scale shared by all panels, set by the SVD
        # spatial components
        vmax = np.max(np.abs([sp1, sp2]))
        vmin = -vmax

        for j, image in enumerate(plotthese):
            ax = plt.subplot(rows, cols, j + 1)
            im = plt.imshow(image,
                            vmin=vmin,
                            vmax=vmax,
                            cmap=iof.config('colormap'))
            ax.set_aspect('equal')
            plt.xticks([])
            plt.yticks([])
            # Color the panel frame like the matching time course below
            for spine in ax.spines.values():
                spine.set_color('C{}'.format(j % 3))
                spine.set_linewidth(2)
            # Title before the colorbar, which may change the current axes
            plt.title(titles[j])
            if j == 2:
                plf.colorbar(im, ticks=[vmin, 0, vmax], format='%.2f')
                barsize = 100 / (stx_h * px_size)
                scalebar = AnchoredSizeBar(ax.transData,
                                           barsize,
                                           '100 µm',
                                           'lower left',
                                           pad=0,
                                           color='k',
                                           frameon=False,
                                           size_vertical=.3)
                ax.add_artist(scalebar)

        t = np.arange(sta.shape[-1]) * frame_duration * 1000
        plt.subplots_adjust(wspace=0.3, hspace=0)
        ax = plt.subplot(rows, 1, 2)
        plt.plot(t, sta[max_i[0], max_i[1], :], label='center px')
        plt.plot(t, t1, label='Temporal 1')
        plt.plot(t, t2, label='Temporal 2')
        plt.xlabel('Time[ms]')
        plf.spineless(ax, 'trlb')  # Turn off spines using custom function

        os.makedirs(plotpath, exist_ok=True)
        plt.savefig(os.path.join(plotpath, clusterids[i] + '.svg'), dpi=300)
        plt.close()
    print(f'Plotted checkerflicker SVD for {stimname}')
Example #26
0
    for i in range(st.nclusters):
        all_spikes[i, :] = st.binnedspiketimes(i)

    stas = np.einsum('abcd,ec->eabd', rw, all_spikes)
    stas /= all_spikes.sum(axis=(-1))[:, np.newaxis, np.newaxis, np.newaxis]

    # Correct for the non-informative parts of the stimulus
    stas = stas - contrast_avg[None, ..., None]

    print(
        f'{msc.timediff(startime)} elapsed for contrast generation and STA calculation'
    )
    #%%
    #    fig1 = plt.figure(1)
    #    fig2 = plt.figure(2)
    data = iof.load(exp, checkerstimnr)
    ckstas = np.array(data['stas'])
    ckstas /= np.nanmax(np.abs(ckstas), axis=0)[np.newaxis, ...]
    ckstas = ckstas[..., ::-1]
    #    imshowkwargs_omb = dict(cmap='RdBu_r', vmin=stas.min(), vmax=stas.max())
    #    imshowkwargs_chk = dict(cmap='RdBu_r', vmin=-np.nanmax(np.abs(ckstas)), vmax=np.nanmax(np.abs(ckstas)))

    #    fig3, axes = plt.subplots(1, 2, num=3)
    #    i = 32
    #    ims = []
    #    for j in range(20):
    #        im_omb = axes[0].imshow(stas[i, :, :, j], **imshowkwargs_omb, animated=True)
    #        im_chk = axes[1].imshow(ckstas[i, :, :, 2*j], **imshowkwargs_chk, animated=True)
    #        ims.append([im_omb, im_chk])
    #
    #    #    plt.show()
Example #27
0
def allfff(exp_name, stim_nrs):
    """
    Plot all of the full field flicker STAs on top of each other
    to see the progression of the cell responses, their firing rates.

    One SVG per cluster is written to ``<exp_dir>/data_analysis/all_fff/``.
    Each figure shows the STAs from all stimuli, plus an inset with the
    spike count per stimulus.

    Parameters
    ----------
    exp_name : str
        Experiment name, resolved with ``iof.exp_dir_fixer``.
    stim_nrs : list of int
        Full field flicker stimulus numbers. If fewer than two are given
        (or a single int), the analysis is skipped and None is returned.

    Raises
    ------
    ValueError
        If the stimulus parameters (other than 'filename') differ between
        the stimuli.
    """

    if isinstance(stim_nrs, int) or len(stim_nrs) <= 1:
        print('Multiple full field flicker stimuli expected, '
              'allfff analysis will be skipped.')
        return

    exp_dir = iof.exp_dir_fixer(exp_name)
    exp_name = os.path.split(exp_dir)[-1]

    # Sanity check to ensure we are comparing the same stimuli and
    # parameters: every stimulus is compared key by key with the first one.
    ref_parameters = None
    for i in stim_nrs:
        pars = asc.read_parameters(exp_name, i)
        currentfname = pars.pop('filename')
        if ref_parameters is None:
            ref_parameters = pars
            continue
        # Use the union of keys so that a parameter present in only one of
        # the stimuli is reported as well.
        for key in {**ref_parameters, **pars}:
            if (key not in pars or key not in ref_parameters
                    or pars[key] != ref_parameters[key]):
                raise ValueError(
                    f'Parameters for {currentfname} do not match!\n'
                    f'{key}:{pars.get(key, "<missing>")}\n'
                    f'{key}:{ref_parameters.get(key, "<missing>")}')

    stimnames = []
    for j, stim in enumerate(stim_nrs):
        data = iof.load(exp_name, stim)
        stas = data['stas']
        clusters = data['clusters']
        filter_length = data['filter_length']
        frame_duration = data['frame_duration']
        if j == 0:
            all_stas = np.zeros(
                (clusters.shape[0], filter_length, len(stim_nrs)))
            all_spikenrs = np.zeros((clusters.shape[0], len(stim_nrs)))
        all_stas[:, :, j] = stas
        all_spikenrs[:, j] = data['spikenrs']
        stimnames.append(iof.getstimname(exp_name, stim))

    # frame_duration is in seconds; convert to ms to match the axis label
    # and keep the sample spacing equal to one frame.
    t = np.arange(filter_length) * frame_duration * 1000

    clusterids = plf.clusters_to_ids(clusters)
    plotpath = os.path.join(exp_dir, 'data_analysis', 'all_fff')
    os.makedirs(plotpath, exist_ok=True)

    for i in range(clusters.shape[0]):
        fig = plt.figure()
        ax1 = plt.subplot(111)
        ax1.plot(t, all_stas[i, :, :])
        ax1.set_xlabel('Time [ms]')
        ax1.legend(stimnames, fontsize='x-small')
        # Inset with the spike count for each stimulus
        ax2 = fig.add_axes([.65, .15, .2, .2])
        for j in range(len(stim_nrs)):
            ax2.plot(j, all_spikenrs[i, j], 'o')
        ax2.set_ylabel('# spikes', fontsize='small')
        ax2.set_xticks([])
        ax2.patch.set_alpha(0)
        plf.spineless(ax1, 'tr')
        plf.spineless(ax2, 'tr')
        plt.suptitle(f'{exp_name}\n {clusterids[i]}')
        plt.savefig(os.path.join(plotpath, clusterids[i]) + '.svg',
                    format='svg',
                    dpi=300)
        plt.close()
    print('Plotted full field flicker STAs together from all stimuli.')